Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://zhuanlan.zhihu.com/p/627039860 | |
| https://github.com/facebookresearch/denoiser/blob/main/denoiser/stft_loss.py | |
| """ | |
| from typing import List | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
class LSDLoss(nn.Module):
    """
    Log Spectral Distance.

    Mean squared error between the log10 power spectra of the estimated and
    the target signal. Note the inputs are already power spectra; the
    n_fft/win_size/hop_size/center parameters are stored only for API
    symmetry with the other losses in this file and are not used here.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        """
        :param n_fft: FFT size (unused in forward; kept for interface symmetry).
        :param win_size: window size (unused in forward).
        :param hop_size: hop size (unused in forward).
        :param center: STFT centering flag (unused in forward).
        :param eps: small constant added before log10 to keep it finite.
        :param reduction: "mean" or "sum", applied over the per-sample errors.
        :raises AssertionError: if reduction is not "mean" or "sum".
        """
        super(LSDLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction
        if reduction not in ("sum", "mean"):
            # plain string: the original used an f-string with no placeholders
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
        """
        :param denoise_power: power spectrum of the estimated signal, shape (batch_size, ...)
        :param clean_power: power spectrum of the target signal, shape (batch_size, ...)
        :return: scalar loss tensor
        """
        # eps keeps log10 finite for zero-valued bins
        log_denoise_power = torch.log10(denoise_power + self.eps)
        log_clean_power = torch.log10(clean_power + self.eps)
        # per-sample mean squared error over the last axis
        mean_square_error = torch.mean(torch.square(log_denoise_power - log_clean_power), dim=-1)
        if self.reduction == "mean":
            lsd_loss = torch.mean(mean_square_error)
        elif self.reduction == "sum":
            lsd_loss = torch.sum(mean_square_error)
        else:
            # unreachable unless self.reduction was mutated after __init__
            raise AssertionError
        return lsd_loss
class ComplexSpectralLoss(nn.Module):
    """
    Loss on the complex STFT: a weighted sum of a magnitude term (L2 norm of
    the complex residual's magnitude), a phase term (L1 norm of the residual's
    angle), and a temporal phase-difference term on the estimate.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 factor_mag: float = 0.5,
                 factor_pha: float = 0.3,
                 factor_gra: float = 0.2,
                 ):
        super().__init__()
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction
        self.factor_mag = factor_mag
        self.factor_pha = factor_pha
        self.factor_gra = factor_gra
        # fixed analysis window; Parameter with requires_grad=False so it
        # follows the module across devices without being trained
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Complex spectrogram [b, f, t] of a batch of waveforms."""
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True,
        )

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: the estimated signal, shape (batch_size, signal_length)
        :param clean: the target signal, shape (batch_size, signal_length)
        :return: scalar loss tensor
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        denoise_spec = self._stft(denoise)
        clean_spec = self._stft(clean)

        # complex residual [b, f, t]; per-sample norms over (freq, time)
        residual = denoise_spec - clean_spec
        magnitude_term = torch.norm(torch.abs(residual), p=2, dim=(-1, -2))
        phase_term = torch.norm(torch.angle(residual), p=1, dim=(-1, -2))

        # temporal phase difference of the estimate, [b, f, t-1]
        phase_step = torch.diff(torch.angle(denoise_spec), dim=-1)
        gradient_term = torch.mean(torch.abs(phase_step), dim=(-1, -2))

        # per-sample weighted combination, shape [b]
        batch_loss = (self.factor_mag * magnitude_term
                      + self.factor_pha * phase_term
                      + self.factor_gra * gradient_term)

        if self.reduction == "mean":
            return torch.mean(batch_loss)
        if self.reduction == "sum":
            return torch.sum(batch_loss)
        # unreachable unless self.reduction was mutated after __init__
        raise AssertionError
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss.

    Frobenius norm of the magnitude error, normalized by the Frobenius norm
    of the target magnitude, reduced over the batch.
    """

    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        super(SpectralConvergenceLoss, self).__init__()
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")
        self.reduction = reduction
        self.eps = eps

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :return: scalar loss tensor
        """
        numerator = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
        # eps guards against division by zero when the target is all-zero
        denominator = torch.norm(clean_magnitude, p="fro", dim=(-1, -2)) + self.eps
        per_sample = numerator / denominator
        if self.reduction not in ("mean", "sum"):
            # unreachable unless self.reduction was mutated after __init__
            raise AssertionError
        reduce_fn = {"mean": torch.mean, "sum": torch.sum}[self.reduction]
        return reduce_fn(per_sample)
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module.

    L1 distance between the log-magnitude spectrograms of the estimate and
    the target.
    """

    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        """
        :param reduction: "mean" or "sum", applied over all elements.
        :param eps: small constant added before log to keep it finite.
        :raises AssertionError: if reduction is not "mean" or "sum".
        """
        super(LogSTFTMagnitudeLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :return: scalar loss tensor
        """
        # bugfix: the original validated and stored self.reduction but never
        # passed it to F.l1_loss, which therefore always used its default
        # "mean" even when reduction="sum" was requested.
        loss = F.l1_loss(
            torch.log(denoise_magnitude + self.eps),
            torch.log(clean_magnitude + self.eps),
            reduction=self.reduction,
        )
        # best-effort diagnostic only; deliberately does not raise
        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
            print("LogSTFTMagnitudeLoss, nan or inf in loss")
        return loss
class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss.

    Computes magnitude spectrograms of both waveforms and returns the pair
    (spectral convergence loss, log STFT magnitude loss).
    """

    def __init__(self,
                 n_fft: int = 1024,
                 win_size: int = 600,
                 hop_size: int = 120,
                 center: bool = True,
                 reduction: str = "mean",
                 ):
        super(STFTLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.reduction = reduction
        # fixed analysis window; Parameter with requires_grad=False so it
        # follows the module across devices without being trained
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
        self.spectral_convergence_loss = SpectralConvergenceLoss(reduction=reduction)
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(reduction=reduction)

    def _magnitude(self, signal: torch.Tensor) -> torch.Tensor:
        """Magnitude spectrogram [b, f, t] of a batch of waveforms."""
        spectrum = torch.stft(
            signal,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True,
        )
        return torch.abs(spectrum)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: the estimated signal, shape (batch_size, signal_length)
        :param clean: the target signal, shape (batch_size, signal_length)
        :return: tuple (sc_loss, mag_loss)
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        denoise_magnitude = self._magnitude(denoise)
        clean_magnitude = self._magnitude(clean)
        sc_loss = self.spectral_convergence_loss.forward(denoise_magnitude, clean_magnitude)
        mag_loss = self.log_stft_magnitude_loss.forward(denoise_magnitude, clean_magnitude)
        return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss.

    Averages STFTLoss over several (n_fft, win_size, hop_size) resolutions
    and combines the spectral-convergence and log-magnitude parts with
    separate weights.
    """

    def __init__(self,
                 fft_size_list: List[int] = None,
                 win_size_list: List[int] = None,
                 hop_size_list: List[int] = None,
                 factor_sc=0.1,
                 factor_mag=0.1,
                 reduction: str = "mean",
                 ):
        super(MultiResolutionSTFTLoss, self).__init__()
        # `or` keeps the original semantics: empty lists also fall back
        fft_sizes = fft_size_list or [512, 1024, 2048]
        win_sizes = win_size_list or [240, 600, 1200]
        hop_sizes = hop_size_list or [50, 120, 240]
        if not (len(fft_sizes) == len(win_sizes) == len(hop_sizes)):
            raise AssertionError
        self.loss_fn_list = nn.ModuleList([
            STFTLoss(
                n_fft=n_fft,
                win_size=win_size,
                hop_size=hop_size,
                reduction=reduction,
            )
            for n_fft, win_size, hop_size in zip(fft_sizes, win_sizes, hop_sizes)
        ])
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: the estimated signal, shape (batch_size, signal_length)
        :param clean: the target signal, shape (batch_size, signal_length)
        :return: scalar loss tensor
        """
        if denoise.shape != clean.shape:
            raise AssertionError(f"Input signals must have the same shape. denoise_audios: {denoise.shape}, clean_audios: {clean.shape}")
        total_sc = 0.0
        total_mag = 0.0
        for stft_loss in self.loss_fn_list:
            sc_l, mag_l = stft_loss.forward(denoise, clean)
            total_sc = total_sc + sc_l
            total_mag = total_mag + mag_l
        resolution_count = len(self.loss_fn_list)
        sc_part = self.factor_sc * (total_sc / resolution_count)
        mag_part = self.factor_mag * (total_mag / resolution_count)
        return sc_part + mag_part
class WeightedMagnitudePhaseLoss(nn.Module):
    """Weighted sum of MSE on STFT magnitudes and MSE on STFT phases."""

    def __init__(self,
                 n_fft: int = 1024,
                 win_size: int = 600,
                 hop_size: int = 120,
                 center: bool = True,
                 reduction: str = "mean",
                 mag_weight: float = 0.9,
                 pha_weight: float = 0.3,
                 ):
        super(WeightedMagnitudePhaseLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.reduction = reduction
        self.mag_weight = mag_weight
        self.pha_weight = pha_weight
        # fixed analysis window; Parameter with requires_grad=False so it
        # follows the module across devices without being trained
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft_real(self, signal: torch.Tensor) -> torch.Tensor:
        """Real-view spectrogram [b, f, t, 2] (real, imag) of a waveform batch."""
        spectrum = torch.stft(
            signal,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True,
        )
        return torch.view_as_real(spectrum)

    @staticmethod
    def _magnitude_and_phase(spec: torch.Tensor):
        """Split a real-view spectrogram into (magnitude, phase) tensors."""
        # small offsets keep sqrt/atan2 numerically stable near zero
        # (same constants as the original implementation)
        magnitude = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
        phase = torch.atan2(spec[..., 1] + 1e-10, spec[..., 0] + 1e-5)
        return magnitude, phase

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: the estimated signal, shape (batch_size, signal_length)
        :param clean: the target signal, shape (batch_size, signal_length)
        :return: scalar loss tensor
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        denoise_mag, denoise_pha = self._magnitude_and_phase(self._stft_real(denoise))
        clean_mag, clean_pha = self._magnitude_and_phase(self._stft_real(clean))
        mag_loss = F.mse_loss(denoise_mag, clean_mag, reduction=self.reduction)
        pha_loss = F.mse_loss(denoise_pha, clean_pha, reduction=self.reduction)
        return self.mag_weight * mag_loss + self.pha_weight * pha_loss
def main():
    """Smoke test: run one of the loss modules on random waveforms."""
    denoise = torch.randn(2, 16000)
    clean = torch.randn(2, 16000)
    loss_fn = WeightedMagnitudePhaseLoss()
    result = loss_fn.forward(denoise, clean)
    print(f"loss: {result.item()}")
    return


if __name__ == "__main__":
    main()