Spaces:
Running on Zero
Running on Zero
File size: 14,011 Bytes
64ec292 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 | import argparse
from typing import Any, Callable, Optional, Union
import auraloss
import torch
import torch.nn.functional as F
from ml_collections import ConfigDict
from torch import nn
from torch_log_wmse import LogWMSE
def multistft_loss(
    y_: torch.Tensor,
    y: torch.Tensor,
    loss_multistft: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """
    Compute a (multi-resolution) STFT-based loss on waveforms.

    Each input is brought to 3 dimensions independently. A 4-D tensor of shape
    (B, S, C, T) is reshaped to (B, S*C, T); a 3-D tensor is used as-is. The
    pair is then handed to the provided multi-resolution STFT criterion (e.g.,
    `auraloss.freq.MultiResolutionSTFTLoss`). That criterion is a widely used
    spectral loss that compares magnitudes across multiple STFT settings.
    See: Steinmetz & Reiss, 2020, "auraloss: Audio-focused loss functions in PyTorch".

    Args:
        y_ (torch.Tensor): Predicted waveform tensor of shape (B, C, T) or (B, S, C, T).
        y (torch.Tensor): Target waveform tensor of shape (B, C, T) or (B, S, C, T),
            compatible with `y_` after reshaping.
        loss_multistft (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
            A callable implementing the MR-STFT loss.

    Returns:
        torch.Tensor: Whatever `loss_multistft` returns (a scalar loss tensor for
        the auraloss criterion).

    Raises:
        ValueError: If either input is not 3- or 4-dimensional.
    """

    def _to_3d(t: torch.Tensor, name: str) -> torch.Tensor:
        # Fold the second and third axes together so 4-D inputs become 3-D.
        if t.dim() == 4:
            return t.reshape(t.shape[0], t.shape[1] * t.shape[2], t.shape[3])
        if t.dim() == 3:
            return t
        raise ValueError(
            f"Invalid shape for {name} array: {t.shape}. Expected 3 or 4 dimensions."
        )

    # Validate/reshape the prediction first so its error takes precedence.
    y1_ = _to_3d(y_, "predicted")
    y1 = _to_3d(y, "target")
    return loss_multistft(y1_, y1)
def masked_loss(
    y_: torch.Tensor, y: torch.Tensor, q: float, coarse: bool = True
) -> torch.Tensor:
    """
    Quantile-trimmed MSE: discard the largest per-row errors before averaging.

    The procedure has four steps:
    1. Compute the elementwise squared error and swap its first two axes.
    2. When `coarse` is set, average away the last two axes.
    3. Flatten the result to 2-D, keeping the (swapped) leading axis as rows.
    4. Threshold each row at its `q`-quantile.

    The threshold is computed on a detached copy, so it acts as a constant
    cutoff. Only entries strictly below it contribute. The final value is the
    mean over all entries, with the discarded ones counted as zeros. This gives
    robustness to outliers, as in trimmed/robust regression losses.
    See classical robust estimation: Huber, 1964; Rousseeuw & Leroy, 1987.

    NOTE: because the comparison is strict, a row holding a single entry (its
    own quantile) contributes nothing, and with q=1 each row's maximum is dropped.

    Args:
        y_ (torch.Tensor): Predicted tensor matching `y`'s shape (at least 2-D;
            at least 3-D when `coarse` is True).
        y (torch.Tensor): Ground-truth tensor.
        q (float): Quantile in [0, 1] used as the per-row cutoff.
        coarse (bool, optional): If True, average over the last two dims before
            masking. Defaults to True.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    sq_err = F.mse_loss(y_, y, reduction="none").transpose(0, 1)
    if coarse:
        sq_err = sq_err.mean(dim=(-1, -2))
    rows = sq_err.reshape(sq_err.shape[0], -1)

    cutoff = torch.quantile(
        rows.detach(), q, dim=1, keepdim=True, interpolation="linear"
    )
    keep = rows < cutoff
    return (rows * keep).mean()
def spec_rmse_loss(
    estimate: torch.Tensor, sources: torch.Tensor, stft_config: dict, eps: float = 1e-8
) -> torch.Tensor:
    """
    RMSE in the complex STFT domain.

    Both signals are flattened to (N, T) using the prediction's last-axis
    length. Each is transformed with `torch.stft`, and the complex values are
    represented as (real, imag) pairs. The leading dims are then restored,
    giving (*estimate.shape[:-1], freq, frames, 2). The squared error is
    averaged over every axis from index 2 onward. `eps` is added, and the
    square root is taken per (dim 0, dim 1) slice. Those RMS values are finally
    averaged. Spectral-domain L2/RMSE losses are common in speech/music
    enhancement.
    See, e.g., Steinmetz & Reiss, 2020; Yamamoto et al., 2020 (Parallel WaveGAN).

    Args:
        estimate (torch.Tensor): Predicted time-domain signal(s), e.g., (B, S, C, T).
        sources (torch.Tensor): Target time-domain signal(s), matching shape.
        stft_config (dict): Keyword arguments for `torch.stft` (e.g., n_fft,
            hop_length, win_length); `return_complex` is set here.
        eps (float, optional): Added before the square root so it (and its
            gradient) stays finite when the error is zero. Defaults to 1e-8.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    n_samples = estimate.shape[-1]
    # `reshape` rather than `view`: accepts non-contiguous inputs (e.g. sliced
    # or transposed tensors) and is a no-copy view when the input is contiguous.
    est_flat = estimate.reshape(-1, n_samples)
    src_flat = sources.reshape(-1, n_samples)

    est_spec = torch.view_as_real(torch.stft(est_flat, **stft_config, return_complex=True))
    src_spec = torch.view_as_real(torch.stft(src_flat, **stft_config, return_complex=True))

    # Restore leading dims: (*estimate.shape[:-1], freq, frames, 2).
    new_shape = estimate.shape[:-1] + est_spec.shape[-3:]
    est_spec = est_spec.reshape(new_shape)
    src_spec = src_spec.reshape(new_shape)

    loss = F.mse_loss(est_spec, src_spec, reduction="none")
    dims = tuple(range(2, loss.dim()))
    return (loss.mean(dims) + eps).sqrt().mean(dim=(0, 1))
def spec_masked_loss(
    estimate: torch.Tensor,
    sources: torch.Tensor,
    stft_config: dict,
    q: float = 0.9,
    coarse: bool = True,
) -> torch.Tensor:
    """
    Quantile-masked MSE in the complex STFT domain.

    The method proceeds as follows:
    - Both signals are flattened to (N, T) using the prediction's last-axis length.
    - Each is transformed with `torch.stft` and represented as (real, imag) pairs.
    - The leading dims are restored, giving (*estimate.shape[:-1], freq, frames, 2).
    - The elementwise squared error is taken.
    - When `coarse` is set, that error is averaged over the freq and frame axes.
    - The result is flattened per leading (dim 0) item.
    - Each row is thresholded at its `q`-quantile, computed without gradient.
    - Only entries strictly below the threshold contribute. The final value is
      the mean over all entries, with discarded ones counted as zeros.

    This gives robustness to outliers and is related to trimmed/robust spectral
    losses.
    See: Huber, 1964; Rousseeuw & Leroy, 1987; spectral losses as in Steinmetz & Reiss, 2020.

    Args:
        estimate (torch.Tensor): Predicted time-domain signal(s), e.g., (B, S, C, T).
        sources (torch.Tensor): Target time-domain signal(s), matching shape.
        stft_config (dict): Keyword arguments for `torch.stft`; `return_complex`
            is set here.
        q (float, optional): Quantile in [0, 1] used as the per-row cutoff.
            Defaults to 0.9.
        coarse (bool, optional): If True, average over the freq/frame axes
            before masking. Defaults to True.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    n_samples = estimate.shape[-1]
    # `reshape` rather than `view`: accepts non-contiguous inputs (e.g. sliced
    # or transposed tensors) and is a no-copy view when the input is contiguous.
    est_flat = estimate.reshape(-1, n_samples)
    src_flat = sources.reshape(-1, n_samples)

    est_spec = torch.view_as_real(torch.stft(est_flat, **stft_config, return_complex=True))
    src_spec = torch.view_as_real(torch.stft(src_flat, **stft_config, return_complex=True))

    # Restore leading dims: (*estimate.shape[:-1], freq, frames, 2).
    new_shape = estimate.shape[:-1] + est_spec.shape[-3:]
    est_spec = est_spec.reshape(new_shape)
    src_spec = src_spec.reshape(new_shape)

    loss = F.mse_loss(est_spec, src_spec, reduction="none")
    if coarse:
        # Axes -3/-2 are freq and frames; the trailing axis holds (real, imag).
        loss = loss.mean(dim=(-3, -2))
    loss = loss.reshape(loss.shape[0], -1)

    # Detached so the cutoff acts as a constant threshold.
    cutoff = torch.quantile(
        loss.detach(), q, interpolation="linear", dim=1, keepdim=True
    )
    keep = loss < cutoff
    # Named to avoid shadowing the module-level `masked_loss` function.
    trimmed = (loss * keep).mean()
    return trimmed
def l1_snr_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Time-domain L1-SNR loss (no additional regularization).

    This is a thin wrapper around `torch_l1_snr.L1SNRLoss`. A new criterion
    instance is constructed on every call.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported inside the function, so torch_l1_snr is only required when
    # this loss is actually evaluated.
    import torch_l1_snr

    return torch_l1_snr.L1SNRLoss(name="l1_snr")(y_, y)
def l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    L1-SNR loss with dB-scale level regularization.

    This is a thin wrapper around `torch_l1_snr.L1SNRDBLoss`. A new criterion
    instance is constructed on every call.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported inside the function, so torch_l1_snr is only required when
    # this loss is actually evaluated.
    import torch_l1_snr

    return torch_l1_snr.L1SNRDBLoss(name="l1_snr_db")(y_, y)
def stft_l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    L1-SNR loss in the multi-resolution STFT domain.

    This is a thin wrapper around `torch_l1_snr.STFTL1SNRDBLoss`. A new
    criterion instance is constructed on every call.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported inside the function, so torch_l1_snr is only required when
    # this loss is actually evaluated.
    import torch_l1_snr

    return torch_l1_snr.STFTL1SNRDBLoss(name="stft_l1_snr_db")(y_, y)
def multi_l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Combined time-domain + STFT-domain L1-SNR loss.

    This is a thin wrapper around `torch_l1_snr.MultiL1SNRDBLoss`, which the
    torch-l1-snr package recommends for most use cases. A new criterion
    instance is constructed on every call.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    # Imported inside the function, so torch_l1_snr is only required when
    # this loss is actually evaluated.
    import torch_l1_snr

    return torch_l1_snr.MultiL1SNRDBLoss(name="multi_l1_snr_db")(y_, y)
def choice_loss(
    args: argparse.Namespace, config: ConfigDict
) -> Callable[[Any, Any, Union[Any, None]], torch.Tensor]:
    """
    Build a composite loss from CLI/config options.

    Returns a callable that sums enabled terms (with per-term coefficients):
    - `masked_loss`: robust, quantile-masked MSE (trimmed MSE; Huber, 1964; Rousseeuw & Leroy, 1987).
    - `mse_loss`: standard mean squared error.
    - `l1_loss`: mean absolute error.
    - `multistft_loss`: multi-resolution STFT magnitude loss (Steinmetz & Reiss, 2020).
    - `log_wmse_loss`: weighted MSE operating in a log/spectral perceptual space (log-weighted MSE).
    - `l1_snr_loss`: L1-SNR loss in time domain (Watcharasupat et al., 2023).
    - `l1_snr_db_loss`: L1-SNR with dB-scale level regularization.
    - `stft_l1_snr_db_loss`: L1-SNR in multi-resolution STFT domain.
    - `multi_l1_snr_db_loss`: combined time + STFT domain L1-SNR (recommended).
    - `spec_rmse_loss`: RMSE in complex STFT domain.
    - `spec_masked_loss`: quantile-masked spectral MSE (robust spectral loss).

    Terms are summed in the order listed above, regardless of their order in
    `args.loss`. The coefficients (`args.<name>_coef`) and the masking options
    (`config["training"]["q"]`, `config["training"]["coarse_loss_clip"]`) are
    read each time the returned loss is evaluated, not when it is built.

    Args:
        args (argparse.Namespace): Parsed arguments. `args.loss` names the active
            terms (an iterable of names, or a single name string), and
            `args.<name>_coef` holds each term's coefficient.
        config (ConfigDict): Configuration with loss hyperparameters (e.g., STFT settings,
            quantile `q`, coarse masking flag).

    Returns:
        Callable[[Any, Any, Optional[Any]], torch.Tensor]: A function `loss(y_pred, y_true, x=None)`
        that computes the weighted sum of the selected loss terms. `x` is only
        used by `log_wmse_loss`, which requires it.

    Raises:
        ValueError: If `args.loss` selects none of the supported loss names.
    """
    supported = (
        "masked_loss", "mse_loss", "l1_loss", "multistft_loss", "log_wmse_loss",
        "l1_snr_loss", "l1_snr_db_loss", "stft_l1_snr_db_loss",
        "multi_l1_snr_db_loss", "spec_rmse_loss", "spec_masked_loss",
    )
    # Normalize to a set of names. Testing `in` directly on a plain string would
    # be a substring check, so e.g. "spec_masked_loss" would also enable "masked_loss".
    selected = {args.loss} if isinstance(args.loss, str) else set(args.loss)

    loss_fns = []
    if "masked_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: masked_loss(
                y_pred,
                y_true,
                q=config["training"]["q"],
                coarse=config["training"]["coarse_loss_clip"],
            )
            * args.masked_loss_coef
        )
    if "mse_loss" in selected:
        mse = nn.MSELoss()
        loss_fns.append(
            lambda y_pred, y_true, x=None: mse(y_pred, y_true) * args.mse_loss_coef
        )
    if "l1_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: F.l1_loss(y_pred, y_true) * args.l1_loss_coef
        )
    if "multistft_loss" in selected:
        loss_options = dict(config.get("loss_multistft", {}))
        stft_loss = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)
        loss_fns.append(
            lambda y_pred, y_true, x=None: multistft_loss(y_pred, y_true, stft_loss)
            * args.multistft_loss_coef
        )
    if "log_wmse_loss" in selected:
        chunk_size = int(getattr(config.audio, "chunk_size", 485100))
        sample_rate = int(getattr(config.audio, "sample_rate", 44100))
        log_wmse = LogWMSE(
            # Integer division: audio_length is the chunk length in whole seconds.
            audio_length=chunk_size // sample_rate,
            sample_rate=sample_rate,
            return_as_loss=True,
            bypass_filter=getattr(config.training, "bypass_filter", False),
        )

        def log_wmse_term(y_pred, y_true, x=None):
            # LogWMSE takes `x` as its first argument (presumably the input
            # mixture), so this term cannot be evaluated without it.
            if x is None:
                raise ValueError("log_wmse_loss requires the input `x`; got None.")
            return log_wmse(x, y_pred, y_true) * args.log_wmse_loss_coef

        loss_fns.append(log_wmse_term)
    if "l1_snr_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: l1_snr_loss(y_pred, y_true)
            * args.l1_snr_loss_coef
        )
    if "l1_snr_db_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: l1_snr_db_loss(y_pred, y_true)
            * args.l1_snr_db_loss_coef
        )
    if "stft_l1_snr_db_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: stft_l1_snr_db_loss(y_pred, y_true)
            * args.stft_l1_snr_db_loss_coef
        )
    if "multi_l1_snr_db_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: multi_l1_snr_db_loss(y_pred, y_true)
            * args.multi_l1_snr_db_loss_coef
        )

    spectral = selected & {"spec_rmse_loss", "spec_masked_loss"}
    if spectral:
        # Built once, so both spectral closures capture the same dict and the
        # name is never rebound after a closure has captured it.
        stft_config = {
            "n_fft": getattr(config.model, "nfft", 4096),
            "hop_length": getattr(config.model, "hop_size", 1024),
            "win_length": getattr(config.model, "win_size", 4096),
            "center": True,
            "normalized": getattr(config.model, "normalized", True),
        }
        if "spec_rmse_loss" in spectral:
            loss_fns.append(
                lambda y_pred, y_true, x=None: spec_rmse_loss(y_pred, y_true, stft_config)
                * args.spec_rmse_loss_coef
            )
        if "spec_masked_loss" in spectral:
            loss_fns.append(
                lambda y_pred, y_true, x=None: spec_masked_loss(
                    y_pred,
                    y_true,
                    stft_config,
                    q=config["training"]["q"],
                    coarse=config["training"]["coarse_loss_clip"],
                )
                * args.spec_masked_loss_coef
            )

    # Without any term the combined loss would be the Python int 0, which
    # fails later (e.g. on `.backward()`) with an unrelated-looking error.
    if not loss_fns:
        raise ValueError(
            f"No supported loss selected: args.loss={args.loss!r}. "
            f"Expected one or more of: {', '.join(supported)}."
        )

    def multi_loss(y_pred: Any, y_true: Any, x: Optional[Any] = None) -> torch.Tensor:
        total = 0
        for fn in loss_fns:
            total = total + fn(y_pred, y_true, x)
        return total

    return multi_loss
|