Spaces:
Running on Zero
Running on Zero
| import argparse | |
| from typing import Any, Callable, Optional, Union | |
| import auraloss | |
| import torch | |
| import torch.nn.functional as F | |
| from ml_collections import ConfigDict | |
| from torch import nn | |
| from torch_log_wmse import LogWMSE | |
def multistft_loss(
    y_: torch.Tensor,
    y: torch.Tensor,
    loss_multistft: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """
    Compute a (multi-resolution) STFT-based loss on waveforms.

    Reshapes 4-D inputs (B, S, C, T) to (B, S*C, T) when needed and delegates
    to a provided multi-resolution STFT criterion
    (e.g., ``auraloss.freq.MultiResolutionSTFTLoss``), a widely used spectral
    loss for audio synthesis/enhancement that compares magnitudes across
    multiple STFT settings.
    See: Steinmetz & Reiss, 2020, "auraloss: Audio-focused loss functions in PyTorch".

    Args:
        y_ (torch.Tensor): Predicted waveform tensor of shape (B, C, T) or (B, S, C, T).
        y (torch.Tensor): Target waveform tensor of shape (B, C, T) or (B, S, C, T).
        loss_multistft (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
            A callable implementing the MR-STFT loss.

    Returns:
        torch.Tensor: Scalar loss tensor.

    Raises:
        ValueError: If either input is not 3- or 4-dimensional.
    """
    # Bug fix: the original code's target branch re-tested `y_.shape` instead
    # of `y.shape`, so a 3-D target paired with a 4-D prediction (or any
    # mis-shaped target) raised UnboundLocalError instead of a clear error.
    # Validate both tensors upfront before any reshaping.
    if y_.dim() not in (3, 4):
        raise ValueError(
            f"Invalid shape for predicted array: {y_.shape}. Expected 3 or 4 dimensions."
        )
    if y.dim() not in (3, 4):
        raise ValueError(
            f"Invalid shape for target array: {y.shape}. Expected 3 or 4 dimensions."
        )

    def _flatten(t: torch.Tensor) -> torch.Tensor:
        # Merge source and channel axes: (B, S, C, T) -> (B, S*C, T);
        # 3-D inputs pass through unchanged.
        if t.dim() == 4:
            return t.reshape(t.shape[0], t.shape[1] * t.shape[2], t.shape[3])
        return t

    return loss_multistft(_flatten(y_), _flatten(y))
def masked_loss(
    y_: torch.Tensor, y: torch.Tensor, q: float, coarse: bool = True
) -> torch.Tensor:
    """
    Quantile-masked ("trimmed") MSE that suppresses outlier residuals.

    An elementwise squared error is computed, optionally averaged over the
    two trailing dims ("coarse"), and entries at or above the per-row
    q-quantile are zeroed before the final mean. This yields robustness to
    outliers akin to trimmed/robust regression losses.
    See classical robust estimation: Huber, 1964; Rousseeuw & Leroy, 1987.

    Args:
        y_ (torch.Tensor): Predicted tensor matching `y`'s shape.
        y (torch.Tensor): Ground-truth tensor.
        q (float): Quantile in (0, 1] used to keep low-error elements.
        coarse (bool, optional): If True, average over last two dims before
            masking. Defaults to True.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    errors = F.mse_loss(y_, y, reduction="none").transpose(0, 1)
    if coarse:
        errors = errors.mean(dim=(-1, -2))
    flat = errors.reshape(errors.shape[0], -1)
    # Detach so the threshold itself contributes no gradient.
    threshold = torch.quantile(
        flat.detach(), q, interpolation="linear", dim=1, keepdim=True
    )
    keep = flat < threshold
    # Masked-out entries are zeroed but still count in the mean's denominator.
    return (flat * keep).mean()
def spec_rmse_loss(
    estimate: torch.Tensor, sources: torch.Tensor, stft_config: dict, eps: float = 1e-8
) -> torch.Tensor:
    """
    Root-mean-square error computed in the complex STFT domain.

    Both signals are transformed with ``torch.stft``, complex bins are split
    into real/imag channels, and an RMSE is taken per (batch, source) pair
    before averaging to a scalar. Spectral-domain L2/RMSE losses are common
    in speech/music enhancement; see, e.g., Steinmetz & Reiss, 2020;
    Yamamoto et al., 2020 (Parallel WaveGAN).

    Args:
        estimate (torch.Tensor): Predicted time-domain signal(s), e.g., (B, S, C, T).
        sources (torch.Tensor): Target time-domain signal(s), matching shape.
        stft_config (dict): Keyword arguments forwarded to ``torch.stft``
            (e.g., n_fft, hop_length, win_length).
        eps (float, optional): Stabilizer added before the square root.
            Defaults to 1e-8.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    num_samples = estimate.shape[-1]
    # torch.stft expects (N, T): fold all leading axes into one batch axis.
    est_spec = torch.stft(
        estimate.view(-1, num_samples), **stft_config, return_complex=True
    )
    src_spec = torch.stft(
        sources.view(-1, num_samples), **stft_config, return_complex=True
    )
    est_spec = torch.view_as_real(est_spec)
    src_spec = torch.view_as_real(src_spec)
    # Restore the original leading axes: (..., F, frames, 2).
    restored = estimate.shape[:-1] + est_spec.shape[-3:]
    est_spec = est_spec.view(*restored)
    src_spec = src_spec.view(*restored)
    sq_err = F.mse_loss(est_spec, src_spec, reduction="none")
    reduce_dims = tuple(range(2, sq_err.dim()))
    # Per-(batch, source) RMSE, then average down to a scalar.
    return (sq_err.mean(reduce_dims) + eps).sqrt().mean(dim=(0, 1))
def spec_masked_loss(
    estimate: torch.Tensor,
    sources: torch.Tensor,
    stft_config: dict,
    q: float = 0.9,
    coarse: bool = True,
) -> torch.Tensor:
    """
    Quantile-masked MSE in the complex STFT domain.

    Computes a complex STFT for prediction and target, forms an elementwise
    spectral MSE, optionally averages frequency/frame dims ("coarse"), and
    zeroes the highest-error elements using the per-row q-quantile threshold
    for robustness to outliers. Related to trimmed/robust spectral losses.
    See: Huber, 1964; Rousseeuw & Leroy, 1987; Steinmetz & Reiss, 2020.

    Args:
        estimate (torch.Tensor): Predicted time-domain signal(s), e.g., (B, S, C, T).
        sources (torch.Tensor): Target time-domain signal(s), matching shape.
        stft_config (dict): Keyword arguments forwarded to ``torch.stft``.
        q (float, optional): Quantile in (0, 1] to keep low-error elements.
            Defaults to 0.9.
        coarse (bool, optional): If True, average over spectral dims before
            masking. Defaults to True.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    num_samples = estimate.shape[-1]
    # torch.stft expects (N, T): fold all leading axes into one batch axis.
    est_spec = torch.stft(
        estimate.view(-1, num_samples), **stft_config, return_complex=True
    )
    src_spec = torch.stft(
        sources.view(-1, num_samples), **stft_config, return_complex=True
    )
    est_spec = torch.view_as_real(est_spec)
    src_spec = torch.view_as_real(src_spec)
    # Restore the original leading axes: (..., F, frames, 2).
    restored = estimate.shape[:-1] + est_spec.shape[-3:]
    est_spec = est_spec.view(*restored)
    src_spec = src_spec.view(*restored)
    sq_err = F.mse_loss(est_spec, src_spec, reduction="none")
    if coarse:
        sq_err = sq_err.mean(dim=(-3, -2))
    flat = sq_err.reshape(sq_err.shape[0], -1)
    # Detach so the threshold itself contributes no gradient.
    threshold = torch.quantile(
        flat.detach(), q, interpolation="linear", dim=1, keepdim=True
    )
    keep = flat < threshold
    # Masked-out entries are zeroed but still count in the mean's denominator.
    return (flat * keep).mean()
def l1_snr_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Time-domain L1-SNR loss.

    L1-based signal-to-noise ratio loss (without additional regularization),
    delegated to the ``torch-l1-snr`` package.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported lazily so the module loads without torch-l1-snr installed.
    from torch_l1_snr import L1SNRLoss

    return L1SNRLoss(name="l1_snr")(y_, y)
def l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    L1-SNR loss with dB-scale level regularization.

    Extends L1-SNR with adaptive level-matching regularization in dB scale,
    delegated to the ``torch-l1-snr`` package.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported lazily so the module loads without torch-l1-snr installed.
    from torch_l1_snr import L1SNRDBLoss

    return L1SNRDBLoss(name="l1_snr_db")(y_, y)
def stft_l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    L1-SNR loss in the multi-resolution STFT domain.

    Applies L1-SNR to complex STFT (real/imaginary) across multiple
    resolutions, delegated to the ``torch-l1-snr`` package.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported lazily so the module loads without torch-l1-snr installed.
    from torch_l1_snr import STFTL1SNRDBLoss

    return STFTL1SNRDBLoss(name="stft_l1_snr_db")(y_, y)
def multi_l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Combined time-domain + STFT-domain L1-SNR loss.

    Balances time-domain and spectral-domain L1-SNR with optional
    regularization; the recommended loss from the ``torch-l1-snr`` package
    for most use cases.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported lazily so the module loads without torch-l1-snr installed.
    from torch_l1_snr import MultiL1SNRDBLoss

    return MultiL1SNRDBLoss(name="multi_l1_snr_db")(y_, y)
def choice_loss(
    args: argparse.Namespace, config: ConfigDict
) -> Callable[[Any, Any, Union[Any, None]], torch.Tensor]:
    """
    Build a composite loss from CLI/config options.

    Returns a callable that sums the enabled terms, each scaled by its
    per-term coefficient from ``args``:
        - `masked_loss`: robust, quantile-masked MSE (trimmed MSE).
        - `mse_loss`: standard mean squared error.
        - `l1_loss`: mean absolute error.
        - `multistft_loss`: multi-resolution STFT magnitude loss
          (Steinmetz & Reiss, 2020).
        - `log_wmse_loss`: log-weighted MSE in a perceptual space.
        - `l1_snr_loss`: time-domain L1-SNR.
        - `l1_snr_db_loss`: L1-SNR with dB-scale level regularization.
        - `stft_l1_snr_db_loss`: L1-SNR in multi-resolution STFT domain.
        - `multi_l1_snr_db_loss`: combined time + STFT L1-SNR (recommended).
        - `spec_rmse_loss`: RMSE in complex STFT domain.
        - `spec_masked_loss`: quantile-masked spectral MSE.

    Args:
        args (argparse.Namespace): Parsed arguments specifying which losses
            are active (``args.loss``) and their coefficients.
        config (ConfigDict): Configuration with loss hyperparameters (e.g.,
            STFT settings, quantile ``q``, coarse masking flag).

    Returns:
        Callable[[Any, Any, Optional[Any]], torch.Tensor]: A function
        ``loss(y_pred, y_true, x=None)`` computing the weighted sum of the
        selected loss terms.
    """

    def _stft_kwargs() -> dict:
        # Shared torch.stft parameters for the spectral losses. A fresh dict
        # is built per branch and bound to a distinct local name, so the
        # closures below never capture a variable that a later branch
        # reassigns (late-binding closure hazard in the original code).
        return {
            "n_fft": getattr(config.model, "nfft", 4096),
            "hop_length": getattr(config.model, "hop_size", 1024),
            "win_length": getattr(config.model, "win_size", 4096),
            "center": True,
            "normalized": getattr(config.model, "normalized", True),
        }

    loss_fns = []
    if "masked_loss" in args.loss:
        loss_fns.append(
            lambda y_pred, y_true, x=None: masked_loss(
                y_pred,
                y_true,
                q=config["training"]["q"],
                coarse=config["training"]["coarse_loss_clip"],
            )
            * args.masked_loss_coef
        )
    if "mse_loss" in args.loss:
        mse = nn.MSELoss()
        loss_fns.append(
            lambda y_pred, y_true, x=None: mse(y_pred, y_true) * args.mse_loss_coef
        )
    if "l1_loss" in args.loss:
        loss_fns.append(
            lambda y_pred, y_true, x=None: F.l1_loss(y_pred, y_true) * args.l1_loss_coef
        )
    if "multistft_loss" in args.loss:
        loss_options = dict(config.get("loss_multistft", {}))
        stft_loss = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)
        loss_fns.append(
            lambda y_pred, y_true, x=None: multistft_loss(y_pred, y_true, stft_loss)
            * args.multistft_loss_coef
        )
    if "log_wmse_loss" in args.loss:
        log_wmse = LogWMSE(
            audio_length=int(getattr(config.audio, "chunk_size", 485100))
            // int(getattr(config.audio, "sample_rate", 44100)),
            sample_rate=int(getattr(config.audio, "sample_rate", 44100)),
            return_as_loss=True,
            bypass_filter=getattr(config.training, "bypass_filter", False),
        )
        # NOTE: this term requires the mixture `x`; multi_loss passes it through.
        loss_fns.append(
            lambda y_pred, y_true, x: log_wmse(x, y_pred, y_true)
            * args.log_wmse_loss_coef
        )
    if "l1_snr_loss" in args.loss:
        loss_fns.append(
            lambda y_pred, y_true, x=None: l1_snr_loss(y_pred, y_true)
            * args.l1_snr_loss_coef
        )
    if "l1_snr_db_loss" in args.loss:
        loss_fns.append(
            lambda y_pred, y_true, x=None: l1_snr_db_loss(y_pred, y_true)
            * args.l1_snr_db_loss_coef
        )
    if "stft_l1_snr_db_loss" in args.loss:
        loss_fns.append(
            lambda y_pred, y_true, x=None: stft_l1_snr_db_loss(y_pred, y_true)
            * args.stft_l1_snr_db_loss_coef
        )
    if "multi_l1_snr_db_loss" in args.loss:
        loss_fns.append(
            lambda y_pred, y_true, x=None: multi_l1_snr_db_loss(y_pred, y_true)
            * args.multi_l1_snr_db_loss_coef
        )
    if "spec_rmse_loss" in args.loss:
        rmse_stft_config = _stft_kwargs()
        loss_fns.append(
            lambda y_pred, y_true, x=None: spec_rmse_loss(
                y_pred, y_true, rmse_stft_config
            )
            * args.spec_rmse_loss_coef
        )
    if "spec_masked_loss" in args.loss:
        masked_stft_config = _stft_kwargs()
        loss_fns.append(
            lambda y_pred, y_true, x=None: spec_masked_loss(
                y_pred,
                y_true,
                masked_stft_config,
                q=config["training"]["q"],
                coarse=config["training"]["coarse_loss_clip"],
            )
            * args.spec_masked_loss_coef
        )

    def multi_loss(y_pred: Any, y_true: Any, x: Optional[Any] = None) -> torch.Tensor:
        # Sum of all active weighted terms; starting from int 0 lets the
        # first term's tensor dtype/device be adopted automatically.
        total = 0
        for fn in loss_fns:
            total = total + fn(y_pred, y_true, x)
        return total

    return multi_loss