| import torch
|
| from scipy.signal import get_window
|
|
|
| from torch import nn
|
|
|
# NOTE(review): everything below is a module-level string literal, not live code.
# It holds a disabled copy of a `LambdaOverlapAdd` module and its
# `_reorder_sources` helper. Python evaluates and discards the string at import
# time, so neither name is defined by this module.
# It cannot simply be uncommented:
#   - `_reorder_sources` calls `PITLossWrapper`, which this file does not import;
#   - the docstring example imports from `asteroid_test`.
# The `get_window` import at the top of the file is only referenced inside this
# string. Consider deleting the block (it is recoverable from version control),
# or restoring it as real code with its dependencies imported.
'''
class LambdaOverlapAdd(torch.nn.Module):
    """Overlap-add with lambda transform on segments.

    Segment input signal, apply lambda function (a neural network for example)
    and combine with OLA.

    Args:
        nnet (callable): Function to apply to each segment.
        n_src (int): Number of sources in the output of nnet.
        window_size (int): Size of segmenting window.
        hop_size (int): Segmentation hop size.
        window (str): Name of the window (see scipy.signal.get_window) used
            for the synthesis.
        reorder_chunks (bool): Whether to reorder each consecutive segment.
            This might be useful when `nnet` is permutation invariant, as
            source assignements might change output channel from one segment
            to the next (in classic speech separation for example).
            Reordering is performed based on the correlation between
            the overlapped part of consecutive segment.

    Examples:
        >>> from asteroid_test import ConvTasNet
        >>> nnet = ConvTasNet(n_src=2)
        >>> continuous_nnet = LambdaOverlapAdd(
        >>>     nnet=nnet,
        >>>     n_src=2,
        >>>     window_size=64000,
        >>>     hop_size=None,
        >>>     window="hanning",
        >>>     reorder_chunks=True,
        >>>     enable_grad=False,
        >>> )
        >>> wav = torch.randn(1, 1, 500000)
        >>> out_wavs = continuous_nnet.forward(wav)
    """

    def __init__(
        self,
        nnet,
        n_src,
        window_size,
        hop_size=None,
        window="hanning",
        reorder_chunks=True,
        enable_grad=False,
    ):
        super().__init__()
        assert window_size % 2 == 0, "Window size must be even"

        self.nnet = nnet
        self.window_size = window_size
        self.hop_size = hop_size if hop_size is not None else window_size // 2
        self.n_src = n_src

        if window:
            window = get_window(window, self.window_size).astype("float32")
            window = torch.from_numpy(window)
            self.use_window = True
        else:
            self.use_window = False

        self.register_buffer("window", window)
        self.reorder_chunks = reorder_chunks
        self.enable_grad = enable_grad

    def ola_forward(self, x):
        """Heart of the class: segment signal, apply func, combine with OLA."""

        assert x.ndim == 3

        batch, channels, n_frames = x.size()
        # Overlap and add:
        # [batch, chans, n_frames] -> [batch, chans, win_size, n_chunks]
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )

        out = []
        n_chunks = unfolded.shape[-1]
        for frame_idx in range(n_chunks):  # for loop to spare memory
            frame = self.nnet(unfolded[..., frame_idx])
            # user must handle multichannel by reshaping to batch
            if frame_idx == 0:
                assert frame.ndim == 3, "nnet should return (batch, n_src, time)"
                assert frame.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
            frame = frame.reshape(batch * self.n_src, -1)

            if frame_idx != 0 and self.reorder_chunks:
                # we determine best perm based on xcorr with previous sources
                frame = _reorder_sources(
                    frame, out[-1], self.n_src, self.window_size, self.hop_size
                )

            if self.use_window:
                frame = frame * self.window
            else:
                frame = frame / (self.window_size / self.hop_size)
            out.append(frame)

        out = torch.stack(out).reshape(n_chunks, batch * self.n_src, self.window_size)
        out = out.permute(1, 2, 0)

        out = torch.nn.functional.fold(
            out,
            (n_frames, 1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )
        return out.squeeze(-1).reshape(batch, self.n_src, -1)

    def forward(self, x):
        """Forward module: segment signal, apply func, combine with OLA.

        Args:
            x (:class:`torch.Tensor`): waveform signal of shape (batch, 1, time).

        Returns:
            :class:`torch.Tensor`: The output of the lambda OLA.
        """
        # Here we can do the reshaping
        with torch.autograd.set_grad_enabled(self.enable_grad):
            olad = self.ola_forward(x)
            return olad


def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
     Reorder sources in current chunk to maximize correlation with previous chunk.
     Used for Continuous Source Separation. Standard dsp correlation is used
     for reordering.


    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
                                        of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
                                        of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
                                    both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current:

    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def reorder_func(x, y):
        x = x[..., :overlap_f]
        y = y[..., -overlap_f:]
        # Mean normalization
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # Negative mean Correlation
        return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)

    # We maximize correlation-like between previous and current.
    pit = PITLossWrapper(reorder_func)
    current = pit(current, previous, return_est=True)[1]
    return current.reshape(batch, frames)
'''
|
|
|
|
|
class DualPathProcessing(nn.Module):
    """Perform Dual-Path processing via overlap-add as in DPRNN [1].

    Processing happens in four steps:

    1. :meth:`unfold` splits a (batch, channels, time) tensor into
       overlapping chunks.
    2. :meth:`intra_process` applies a module inside each chunk.
    3. :meth:`inter_process` applies a module across chunks.
    4. :meth:`fold` stitches the chunks back together with overlap-add.

    Args:
        chunk_size (int): Size of segmenting window.
        hop_size (int): Segmentation hop size. :meth:`fold` reconstructs an
            unmodified input exactly only when ``hop_size`` divides
            ``chunk_size``; otherwise its normalization is approximate.

    References:
        [1] "Dual-path RNN: efficient long sequence modeling for
        time-domain single-channel speech separation", Yi Luo, Zhuo Chen
        and Takuya Yoshioka. https://arxiv.org/abs/1910.06379
    """

    def __init__(self, chunk_size, hop_size):
        super(DualPathProcessing, self).__init__()
        self.chunk_size = chunk_size
        self.hop_size = hop_size
        # Time length of the last tensor passed to `unfold`; `fold` uses it
        # as the default output length.
        self.n_orig_frames = None

    def unfold(self, x):
        """Unfold the feature tensor.

        The tensor goes from (batch, channels, time) to
        (batch, channels, chunk_size, n_chunks).

        The time axis is zero-padded by ``chunk_size`` on both sides before
        chunking. Hence ``n_chunks == (time + chunk_size) // hop_size + 1``.
        The input length is cached on the instance for a later :meth:`fold`.

        Args:
            x (:class:`torch.Tensor`): feature tensor of shape
                (batch, channels, time).

        Returns:
            :class:`torch.Tensor`: spliced feature tensor of shape
            (batch, channels, chunk_size, n_chunks).
        """
        # Check the rank before unpacking, otherwise the unpack fails first
        # with a less helpful error.
        assert x.ndim == 3, "x should be of shape (batch, channels, time)"
        batch, chan, frames = x.size()
        self.n_orig_frames = frames
        # Treat time as the height of a (time, 1) "image" so that F.unfold
        # extracts sliding windows along it.
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        # F.unfold returns (batch, chan * chunk_size, n_chunks) with the
        # channel index outermost, so this reshape splits it cleanly.
        return unfolded.reshape(batch, chan, self.chunk_size, -1)

    def fold(self, x, output_size=None):
        """Fold back the spliced feature tensor using overlap-add.

        The input has shape (batch, channels, chunk_size, n_chunks) and is
        folded back to (batch, channels, time).

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            output_size (int, optional): sequence length of the original
                feature tensor. If None, the length cached by the most
                recent call to :meth:`unfold` is used.

        Returns:
            :class:`torch.Tensor`: feature tensor of shape
            (batch, channels, output_size).

        Raises:
            ValueError: if ``output_size`` is None and :meth:`unfold` has
                not been called yet, so no length is cached.

        .. note:: :meth:`unfold` caches the length of its input. ``fold``
            reuses that cached length only when ``output_size`` is not given.
        """
        output_size = output_size if output_size is not None else self.n_orig_frames
        if output_size is None:
            raise ValueError(
                "output_size must be given when `unfold` has not been called first"
            )

        batch, chan, _, n_chunks = x.size()
        to_fold = x.reshape(batch, chan * self.chunk_size, n_chunks)
        # F.fold sums overlapping chunks back onto the padded time axis and
        # crops the padding added in `unfold`.
        x = torch.nn.functional.fold(
            to_fold,
            (output_size, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )

        # Each output sample is the sum of chunk_size / hop_size overlapping
        # chunks (exactly so when hop_size divides chunk_size); undo that.
        x = x / (self.chunk_size / self.hop_size)

        # Use the length actually folded to, not the cached one, so an
        # explicit `output_size` is honored.
        return x.reshape(batch, chan, output_size)

    @staticmethod
    def intra_process(x, module):
        """Perform intra-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (callable): module one wishes to apply to each chunk of
                the spliced feature tensor.

        Returns:
            :class:`torch.Tensor`: processed spliced feature tensor of shape
            (batch, out_channels, chunk_size, n_chunks). Here
            ``out_channels`` is the channel count returned by ``module``.

        .. note:: the module should have the channel first convention.
            It receives a 3D tensor of shape
            (batch * n_chunks, channels, chunk_size). It must return a 3D
            tensor with the same leading and trailing sizes; the channel
            count may change.
        """
        batch, channels, chunk_size, n_chunks = x.size()
        # Put every chunk in the batch dimension:
        # (batch, channels, chunk_size, n_chunks)
        #   -> (batch * n_chunks, channels, chunk_size)
        x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
        x = module(x)
        out_channels = x.shape[1]
        # (batch * n_chunks, out_channels, chunk_size)
        #   -> (batch, out_channels, chunk_size, n_chunks)
        x = x.reshape(batch, n_chunks, out_channels, chunk_size).transpose(1, -1).transpose(1, 2)
        return x

    @staticmethod
    def inter_process(x, module):
        """Perform inter-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (callable): module one wishes to apply between the chunks
                of the spliced feature tensor.

        Returns:
            :class:`torch.Tensor`: processed spliced feature tensor of shape
            (batch, out_channels, chunk_size, n_chunks). Here
            ``out_channels`` is the channel count returned by ``module``.

        .. note:: the module should have the channel first convention.
            It receives a 3D tensor of shape
            (batch * chunk_size, channels, n_chunks). It must return a 3D
            tensor with the same leading and trailing sizes; the channel
            count may change.
        """
        batch, channels, chunk_size, n_chunks = x.size()
        # Put every intra-chunk position in the batch dimension:
        # (batch, channels, chunk_size, n_chunks)
        #   -> (batch * chunk_size, channels, n_chunks)
        x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
        x = module(x)
        out_channels = x.shape[1]
        # (batch * chunk_size, out_channels, n_chunks)
        #   -> (batch, out_channels, chunk_size, n_chunks)
        x = x.reshape(batch, chunk_size, out_channels, n_chunks).transpose(1, 2)
        return x
|
|
|