Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://zhuanlan.zhihu.com/p/627039860 | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget | |
class NegativeSNRLoss(nn.Module):
    """
    Negative Signal-to-Noise Ratio (SNR) loss.

    Both signals are made zero-mean along the last dimension, the residual
    ``denoise - clean`` is treated as noise, and the SNR is measured in dB.
    No projection onto ``clean`` is applied, so any gain mismatch between the
    two signals counts as noise. The SNR is negated so that minimising this
    loss maximises the SNR.
    """
    def __init__(self, eps: float = 1e-8, reduction: str = "mean"):
        """
        :param eps: Small constant added to both power terms so the ratio and
            its log10 stay finite when either power is zero.
        :param reduction: How per-example losses are combined: "mean"
            (default), "sum", or "none" (return one value per example).
        """
        super(NegativeSNRLoss, self).__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SNR between the estimated signal and the target signal.

        :param denoise: The estimated signal, shape (batch_size, signal_length).
        :param clean: The target signal, same shape as ``denoise``.
        :return: The negative SNR in dB. A scalar for reduction "mean" or
            "sum"; one value per example (shape ``clean.shape[:-1]``) for "none".
        :raises AssertionError: If the input shapes differ or the reduction is
            not one of "mean", "sum", "none".
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        if self.reduction not in ("mean", "sum", "none"):
            raise AssertionError(
                f"Unsupported reduction {self.reduction!r}; expected 'mean', 'sum' or 'none'"
            )

        # Remove the per-example DC offset before measuring power.
        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)
        noise = denoise - clean

        clean_power = torch.norm(clean, p=2, dim=-1) ** 2
        noise_power = torch.norm(noise, p=2, dim=-1) ** 2
        # eps on both sides keeps the ratio strictly positive and finite.
        snr = 10 * torch.log10((clean_power + self.eps) / (noise_power + self.eps))

        if self.reduction == "mean":
            return -snr.mean()
        if self.reduction == "sum":
            return -snr.sum()
        return -snr
class NegativeSISNRLoss(nn.Module):
    """
    Negative Scale-Invariant Signal-to-Noise Ratio (SI-SNR) loss.
    https://arxiv.org/abs/2206.07293

    After removing the per-example mean, the estimate is projected onto the
    target. The projection is the "target" component and the remainder is
    the "noise" component. The loss is the negated SI-SNR in dB, so
    minimising it maximises SI-SNR.
    """
    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        """
        :param reduction: How per-example losses are combined: "mean"
            (default), "sum", or "none" (return one value per example).
        :param eps: Small constant that guards the projection denominator,
            the noise-power denominator, and the log10 argument.
        """
        super(NegativeSISNRLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SI-SNR between the estimated signal and the target signal.

        :param denoise: The estimated signal, shape (batch_size, signal_length).
        :param clean: The target signal, same shape as ``denoise``.
        :return: The negative SI-SNR in dB. A scalar for reduction "mean" or
            "sum"; one value per example (shape ``clean.shape[:-1]``) for "none".
        :raises AssertionError: If the input shapes differ or the reduction is
            not one of "mean", "sum", "none".
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        if self.reduction not in ("mean", "sum", "none"):
            raise AssertionError(
                f"Unsupported reduction {self.reduction!r}; expected 'mean', 'sum' or 'none'"
            )

        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        # Orthogonal projection of the estimate onto the target direction:
        # <denoise, clean> / ||clean||^2 * clean.
        s_target = torch.sum(denoise * clean, dim=-1, keepdim=True) * clean / (
            torch.norm(clean, p=2, dim=-1, keepdim=True) ** 2 + self.eps
        )
        e_noise = denoise - s_target

        # The trailing eps keeps log10 finite when s_target is all zeros.
        batch_si_snr = 10 * torch.log10(
            torch.norm(s_target, p=2, dim=-1) ** 2
            / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps)
            + self.eps
        )
        # batch_si_snr shape: clean.shape[:-1], i.e. [batch_size] for 2-D input

        if self.reduction == "mean":
            loss = torch.mean(batch_si_snr)
        elif self.reduction == "sum":
            loss = torch.sum(batch_si_snr)
        else:
            loss = batch_si_snr
        return -loss
class LocalSNRLoss(nn.Module):
    """
    Regression loss between a predicted local SNR and a local-SNR target
    derived from the clean and noisy waveforms.

    Ported from DeepFilterNet:
    https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
    """
    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: int = -15,
                 max_local_snr: int = 30,
                 db: bool = True,
                 factor: float = 1,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 center: bool = True,
                 ):
        """
        :param sample_rate: Forwarded to ``LocalSnrTarget``.
        :param nfft: FFT size for the STFTs, also forwarded to ``LocalSnrTarget``.
        :param win_size: Hann window length, also forwarded to ``LocalSnrTarget``.
        :param hop_size: STFT hop length, also forwarded to ``LocalSnrTarget``.
        :param n_frame: Forwarded to ``LocalSnrTarget``.
        :param min_local_snr: Forwarded to ``LocalSnrTarget``.
        :param max_local_snr: Forwarded to ``LocalSnrTarget``.
        :param db: Forwarded to ``LocalSnrTarget``.
        :param factor: Multiplier applied to the MSE.
        :param reduction: Reduction passed to ``F.mse_loss`` ("mean", "sum" or "none").
        :param eps: Stored on the module; not used by ``forward``.
        :param center: Passed to ``torch.stft``. Previously this attribute was
            read but never set, which made ``forward`` fail. The default
            ``True`` is ``torch.stft``'s own default.
            NOTE(review): confirm it matches how the predicted ``lsnr`` frames are produced.
        """
        super(LocalSNRLoss, self).__init__()
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.factor = factor
        self.reduction = reduction
        self.eps = eps
        self.center = center

        self.lsnr_fn = LocalSnrTarget(
            sample_rate=sample_rate,
            nfft=nfft,
            win_size=win_size,
            hop_size=hop_size,
            n_frame=n_frame,
            min_local_snr=min_local_snr,
            max_local_snr=max_local_snr,
            db=db,
        )
        # Fixed (non-trainable) Hann analysis window. Registered on the module
        # so that .to(device) moves it along with the loss.
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Complex STFT of ``signal`` using this loss's analysis settings."""
        return torch.stft(
            signal,
            n_fft=self.nfft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

    def forward(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor) -> torch.Tensor:
        """
        :param lsnr: Predicted local SNR, shape [b, 1, t].
        :param clean: Clean waveform. Presumably shape (batch, signal_length); confirm against the caller.
        :param noisy: Noisy waveform, same shape as ``clean``. The noise is
            taken to be ``noisy - clean``.
        :return: ``factor`` times the MSE between ``lsnr`` and the local-SNR
            target, reduced according to ``reduction``.
        :raises AssertionError: If ``clean`` and ``noisy`` shapes differ.
        """
        if clean.shape != noisy.shape:
            raise AssertionError("Input signals must have the same shape")
        noise = noisy - clean

        stft_clean = self._stft(clean)
        stft_noise = self._stft(noise)

        # lsnr shape: [b, 1, t] -> [b, t]
        lsnr = lsnr.squeeze(1)
        # lsnr_gth shape: [b, t]
        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)

        loss = F.mse_loss(lsnr, lsnr_gth, reduction=self.reduction) * self.factor
        return loss
def main():
    """Smoke-test NegativeSISNRLoss: random estimates vs. an all-zero target, then print the loss.

    An all-zero target drives the SI-SNR computation through its eps guards.
    Swap in ``torch.randn(*shape)`` for the target to try a non-degenerate case.
    """
    batch_size, signal_length = 2, 16000
    shape = (batch_size, signal_length)

    estimated_signal = torch.randn(*shape)
    target_signal = torch.zeros(*shape)

    criterion = NegativeSISNRLoss()
    value = criterion.forward(estimated_signal, target_signal)
    print(f"loss: {value.item()}")


if __name__ == "__main__":
    main()