| from typing import Optional |
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| from torch.utils.data import DataLoader |
|
|
def atan2(y, x):
    r"""Element-wise arctangent of ``y / x`` with quadrant correction.

    Alternative implementation of :func:`torch.atan2`, returning a new
    tensor with signed angles in radians.

    Args:
        y (Tensor): First input tensor (numerator)
        x (Tensor): Second input tensor (denominator) [shape=y.shape]

    Returns:
        Tensor: [shape=y.shape].
    """
    pi = 2 * torch.asin(torch.tensor(1.0))
    # Where both inputs are zero, substitute x by 1 so that y / x is well
    # defined (and yields angle 0).  This must be done OUT of place: the
    # previous `x += ...` mutated the caller's tensor whenever a view was
    # passed in (e.g. `mix_stft[..., 0]` from :func:`wiener`).
    x = x + ((x == 0) & (y == 0)) * 1.0
    out = torch.atan(y / x)
    # torch.atan only covers (-pi/2, pi/2); shift results into the
    # correct quadrant when x < 0.
    out = out + ((y >= 0) & (x < 0)) * pi
    out = out - ((y < 0) & (x < 0)) * pi
    # Handle the vertical axis (x == 0, y != 0) explicitly as +-pi/2.
    out = out * (1 - ((y > 0) & (x == 0)) * 1.0)
    out = out + ((y > 0) & (x == 0)) * (pi / 2)
    out = out * (1 - ((y < 0) & (x == 0)) * 1.0)
    out = out + ((y < 0) & (x == 0)) * (-pi / 2)
    return out
|
|
|
|
| |
| |
|
|
|
|
| def _norm(x: torch.Tensor) -> torch.Tensor: |
| r"""Computes the norm value of a torch Tensor, assuming that it |
| comes as real and imaginary part in its last dimension. |
| |
| Args: |
| x (Tensor): Input Tensor of shape [shape=(..., 2)] |
| |
| Returns: |
| Tensor: shape as x excluding the last dimension. |
| """ |
| return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2 |
|
|
|
|
| def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: |
| """Element-wise multiplication of two complex Tensors described |
| through their real and imaginary parts. |
| The result is added to the `out` tensor""" |
|
|
| |
| target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)]) |
| if out is None or out.shape != target_shape: |
| out = torch.zeros(target_shape, dtype=a.dtype, device=a.device) |
| if out is a: |
| real_a = a[..., 0] |
| out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1]) |
| out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0]) |
| else: |
| out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]) |
| out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]) |
| return out |
|
|
|
|
| def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: |
| """Element-wise multiplication of two complex Tensors described |
| through their real and imaginary parts |
| can work in place in case out is a only""" |
| target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)]) |
| if out is None or out.shape != target_shape: |
| out = torch.zeros(target_shape, dtype=a.dtype, device=a.device) |
| if out is a: |
| real_a = a[..., 0] |
| out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1] |
| out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0] |
| else: |
| out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1] |
| out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0] |
| return out |
|
|
|
|
| def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: |
| """Element-wise multiplicative inverse of a Tensor with complex |
| entries described through their real and imaginary parts. |
| can work in place in case out is z""" |
| ez = _norm(z) |
| if out is None or out.shape != z.shape: |
| out = torch.zeros_like(z) |
| out[..., 0] = z[..., 0] / ez |
| out[..., 1] = -z[..., 1] / ez |
| return out |
|
|
|
|
| def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor: |
| """Element-wise complex conjugate of a Tensor with complex entries |
| described through their real and imaginary parts. |
| can work in place in case out is z""" |
| if out is None or out.shape != z.shape: |
| out = torch.zeros_like(z) |
| out[..., 0] = z[..., 0] |
| out[..., 1] = -z[..., 1] |
| return out |
|
|
|
|
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Invert 1x1 or 2x2 matrices

    Will generate errors if the matrices are singular: user must handle this
    through his own regularization schemes.

    Args:
        M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
            matrices to invert: must be square along dimensions -3 and -2
        out (Tensor, optional): [shape=M.shape]
            destination tensor; reallocated if None or of the wrong shape

    Returns:
        invM (Tensor): [shape=M.shape]
            inverses of M
    """
    nb_channels = M.shape[-2]

    if out is None or out.shape != M.shape:
        out = torch.empty_like(M)

    if nb_channels == 1:
        # 1x1 case: the matrix inverse is just the complex reciprocal
        out = _inv(M, out)
    elif nb_channels == 2:
        # determinant of the 2x2 matrix:
        # det = M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0] (complex products)
        det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
        det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
        # reciprocal of the determinant (inf/nan if singular)
        invDet = _inv(det)

        # closed-form 2x2 inverse: swap the diagonal entries, negate the
        # off-diagonal ones, and scale everything by 1/det.  Each _mul
        # writes directly into the corresponding slice of `out`.
        out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
        out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
        out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
        out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
    else:
        raise Exception("Only 2 channels are supported for the torch version.")
    return out
|
|
|
|
| |
|
|
|
|
def expectation_maximization(
    y: torch.Tensor,
    x: torch.Tensor,
    iterations: int = 2,
    eps: float = 1e-10,
    batch_size: int = 200,
):
    r"""Expectation maximization algorithm, for refining source separation
    estimates.

    This algorithm allows to make source separation results better by
    enforcing multichannel consistency for the estimates. This usually means
    a better perceptual quality in terms of spatial artifacts.

    The implementation follows the details presented in [1]_, taking
    inspiration from the original EM algorithm proposed in [2]_ and its
    weighted refinement proposed in [3]_, [4]_.
    It works by iteratively:

    * Re-estimating source parameters (power spectral densities and spatial
      covariance matrices).

    * Separating again the mixture with the new parameters through
      multichannel Wiener filtering.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
        reverberant audio source separation using a full-rank spatial
        covariance model." IEEE Transactions on Audio, Speech, and Language
        Processing 18.7 (2010): 1830-1840.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            initial estimates for the sources
        x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
            complex STFT of the mixture signal
        iterations (int): [scalar]
            number of iterations for the EM algorithm.
        eps (float or None): [scalar]
            The epsilon value to use for regularization and filters.
        batch_size (int): [scalar]
            number of frames processed jointly per batch; trades memory
            for speed. A falsy value means "all frames at once".

    Returns:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            estimated sources after iterations
        v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
            estimated power spectral densities
        R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
            estimated spatial covariance matrices

    Notes:
        * You need an initial estimate for the sources to apply this
          algorithm. This is precisely what the :func:`wiener` function does.
        * This algorithm *is not* an implementation of the "exact" EM
          proposed in [1]_. In particular, it does not compute the posterior
          covariance matrices the same (exact) way. Instead, it uses the
          simplified approximate scheme initially proposed in [5]_ and further
          refined in [3]_, [4]_, that boils down to just take the empirical
          covariance of the recent source estimates, followed by a weighted
          average for the update of the spatial covariance matrix. It has been
          empirically demonstrated that this simplified algorithm is more
          robust for music separation.

    Warning:
        It is *very* important to make sure `x.dtype` is `torch.float64`
        if you want double precision, because this function will **not**
        do such conversion for you, in case you want the smaller RAM
        usage on purpose.

        It is usually always better in terms of quality to have double
        precision, by e.g. calling :func:`expectation_maximization`
        with ``x.to(torch.float64)``.
    """
    # dimensions of the mixture STFT (last dim of x is real/imag = 2)
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]

    # diagonal loading term: sqrt(eps) * identity stored as (real, imag)
    # pairs (imaginary part zero), broadcast over frames and bins; added
    # to the mixture covariance before inversion to avoid singularities
    regularization = torch.cat(
        (
            torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
            torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
        ),
        dim=2,
    )
    regularization = torch.sqrt(torch.as_tensor(eps)) * (
        regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
    )

    # one spatial covariance matrix per source:
    # (nb_bins, nb_channels, nb_channels, 2)
    R = [
        torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
        for j in range(nb_sources)
    ]
    weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)

    # power spectral densities of the sources
    v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
    for it in range(iterations):
        # update the PSDs as the average squared magnitude of the
        # current source estimates over channels (dim=-2)
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)

        # update each source's spatial covariance matrix as the
        # PSD-weighted average of empirical covariances over frames,
        # accumulated batch-wise to limit memory usage
        for j in range(nb_sources):
            R[j] = torch.tensor(0.0, device=x.device)
            weight = torch.tensor(eps, device=x.device)
            pos: int = 0
            batch_size = batch_size if batch_size else nb_frames
            while pos < nb_frames:
                t = torch.arange(pos, min(nb_frames, pos + batch_size))
                pos = int(t[-1]) + 1

                R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
                weight = weight + torch.sum(v[t, ..., j], dim=0)
            R[j] = R[j] / weight[..., None, None, None]
            weight = torch.zeros_like(weight)

        # the in-place sliced assignments below are not allowed on a
        # tensor that requires grad, so clone first
        if y.requires_grad:
            y = y.clone()

        pos = 0
        while pos < nb_frames:
            t = torch.arange(pos, min(nb_frames, pos + batch_size))
            pos = int(t[-1]) + 1

            # reset the current batch of estimates before re-separating
            y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)

            # model of the mixture covariance: regularization plus the
            # sum over sources of v_j * R_j
            Cxx = regularization
            for j in range(nb_sources):
                Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())

            # invert the modelled mixture covariance (1x1 or 2x2 only)
            inv_Cxx = _invert(Cxx)

            # separate the sources with multichannel Wiener filtering
            for j in range(nb_sources):
                # Wiener gain for source j: v_j * R_j @ inv(Cxx),
                # built entry-wise through complex multiply-accumulate
                gain = torch.zeros_like(inv_Cxx)

                # iterate over all (row, col, summation) channel index
                # triples of the matrix product
                indices = torch.cartesian_prod(
                    torch.arange(nb_channels),
                    torch.arange(nb_channels),
                    torch.arange(nb_channels),
                )
                for index in indices:
                    gain[:, :, index[0], index[1], :] = _mul_add(
                        R[j][None, :, index[0], index[2], :].clone(),
                        inv_Cxx[:, :, index[2], index[1], :],
                        gain[:, :, index[0], index[1], :],
                    )
                gain = gain * v[t, ..., None, None, None, j]

                # apply the gain to the mixture: y_j = gain @ x
                for i in range(nb_channels):
                    y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])

    return y, v, R
|
|
|
|
def wiener(
    targets_spectrograms: torch.Tensor,
    mix_stft: torch.Tensor,
    iterations: int = 1,
    softmask: bool = False,
    residual: bool = False,
    scale_factor: float = 10.0,
    eps: float = 1e-10,
):
    """Wiener-based separation for multichannel audio.

    Separates the complex STFT of a mixture using (possibly multichannel)
    spectrogram estimates of the targets.  The procedure is:

    * Build initial complex estimates: either by combining each target
      spectrogram with the mixture phase (``softmask=False``), or by
      soft-masking the complex mixture with the ratio of each target
      spectrogram to their sum (``softmask=True``).
    * Optionally append a residual target equal to the mixture minus the
      sum of all targets (``residual=True``).
    * Refine the estimates with :func:`expectation_maximization` when
      ``iterations`` is nonzero, after rescaling the signals for
      numerical stability.

    Based on [1]_, [2]_, [3]_, [4]_ (see :func:`expectation_maximization`
    for the full reference list).

    Args:
        targets_spectrograms (Tensor): nonnegative spectrograms of the
            sources [shape=(nb_frames, nb_bins, nb_channels, nb_sources)],
            typically the output of the user's separation model.  They may
            be mono but must be 4-dimensional in all cases.
        mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
            STFT of the mixture signal.
        iterations (int): number of EM iterations; 0 returns the initial
            estimates unchanged.
        softmask (bool): if True, use ratio soft-masking for the initial
            estimates; if False, use the spectrograms with the mixture
            phase directly (requires *magnitude* spectrograms).
        residual (bool): if True, append an extra target equal to the
            mixture minus all other targets before the EM refinement.
        scale_factor (float): controls the peak-normalization applied
            before the EM iterations (signals are divided so the peak
            magnitude is at most ``scale_factor``).
        eps (float): epsilon used whenever dividing by model energy;
            acts like the energy of additional white noise removed when
            separating, and strongly affects performance.

    Returns:
        Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
            STFT of the estimated sources.

    Notes:
        * ``softmask=False`` is recommended, and needs magnitude
          spectrogram estimates.

    Warning:
        As in :func:`expectation_maximization`, we recommend converting
        the mixture to double precision ``torch.float64`` before calling
        :func:`wiener`.
    """
    if softmask:
        # ratio mask: each target takes its share of the complex mixture
        total = torch.sum(targets_spectrograms, dim=-1, keepdim=True)
        mask = targets_spectrograms / (eps + total.to(mix_stft.dtype))
        y = mix_stft[..., None] * mask[..., None, :]
    else:
        # combine the target magnitudes with the mixture phase
        mix_phase = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
        n_targets = targets_spectrograms.shape[-1]
        y = torch.zeros(
            mix_stft.shape + (n_targets,), dtype=mix_stft.dtype, device=mix_stft.device
        )
        y[..., 0, :] = targets_spectrograms * torch.cos(mix_phase)
        y[..., 1, :] = targets_spectrograms * torch.sin(mix_phase)

    if residual:
        # extra target: whatever of the mixture the targets do not explain
        leftover = mix_stft[..., None] - y.sum(dim=-1, keepdim=True)
        y = torch.cat([y, leftover], dim=-1)

    if iterations == 0:
        return y

    # rescale both mixture and estimates so the peak magnitude is bounded,
    # which keeps the EM iterations numerically well behaved
    peak = torch.max(
        torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
        torch.sqrt(_norm(mix_stft)).max() / scale_factor,
    )
    mix_stft = mix_stft / peak
    y = y / peak

    # refine the estimates through EM, keeping only the sources
    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]

    # undo the scaling
    y = y * peak
    return y
|
|
|
|
def _covariance(y_j):
    """
    Compute the empirical covariance for a source.

    Args:
        y_j (Tensor): complex stft of the source,
            [shape=(nb_frames, nb_bins, nb_channels, 2)].

    Returns:
        Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
            the outer product y_j * conj(y_j.T), i.e. the empirical
            covariance for each TF bin.
    """
    nb_frames, nb_bins, nb_channels = y_j.shape[:-1]
    Cj = torch.zeros(
        (nb_frames, nb_bins, nb_channels, nb_channels, 2),
        dtype=y_j.dtype,
        device=y_j.device,
    )
    # fill each (row, col) channel pair with y[row] * conj(y[col])
    for row in range(nb_channels):
        for col in range(nb_channels):
            Cj[:, :, row, col, :] = _mul_add(
                y_j[:, :, row, :],
                _conj(y_j[:, :, col, :]),
                Cj[:, :, row, col, :],
            )
    return Cj
|
|