# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple

import numpy as np
import torch

from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.collections.asr.parts.utils.audio_utils import db2mag, wrap_to_pi
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging
from nemo.utils.decorators import experimental

try:
    import torchaudio

    HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
    HAVE_TORCHAUDIO = False

__all__ = [
    'MaskEstimatorRNN',
    'MaskReferenceChannel',
    'MaskBasedBeamformer',
    'MaskBasedDereverbWPE',
]


@experimental
class SpectrogramToMultichannelFeatures(NeuralModule):
    """Convert a complex-valued multi-channel spectrogram to multichannel features.

    Args:
        num_subbands: Expected number of subbands in the input signal
        num_input_channels: Optional, provides the number of channels
                            of the input signal. Used to infer the number
                            of output channels.
        mag_reduction: Reduction across channels. Default `None`, will calculate
                       magnitude of each channel.
        use_ipd: Use inter-channel phase difference (IPD).
        mag_normalization: Normalization for magnitude features
        ipd_normalization: Normalization for IPD features
    """

    def __init__(
        self,
        num_subbands: int,
        num_input_channels: Optional[int] = None,
        mag_reduction: Optional[str] = 'rms',
        use_ipd: bool = False,
        mag_normalization: Optional[str] = None,
        ipd_normalization: Optional[str] = None,
    ):
        super().__init__()
        self.mag_reduction = mag_reduction
        self.use_ipd = use_ipd

        # TODO: normalization
        # Normalization is not implemented yet; reject any non-default value
        # explicitly instead of silently ignoring it.
        if mag_normalization is not None:
            raise NotImplementedError(f'Unknown magnitude normalization {mag_normalization}')
        self.mag_normalization = mag_normalization

        if ipd_normalization is not None:
            raise NotImplementedError(f'Unknown ipd normalization {ipd_normalization}')
        self.ipd_normalization = ipd_normalization

        if self.use_ipd:
            # Magnitude features and IPD features are concatenated along the
            # feature dimension, doubling the number of features per channel.
            self._num_features = 2 * num_subbands
            self._num_channels = num_input_channels
        else:
            self._num_features = num_subbands
            # With a channel-wise magnitude reduction, the output collapses to
            # a single channel regardless of the number of input channels.
            self._num_channels = num_input_channels if self.mag_reduction is None else 1

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def num_features(self) -> int:
        """Configured number of features
        """
        return self._num_features

    @property
    def num_channels(self) -> int:
        """Configured number of channels
        """
        if self._num_channels is not None:
            return self._num_channels
        else:
            raise ValueError(
                'Num channels is not configured. To configure this, `num_input_channels` '
                'must be provided when constructing the object.'
            )

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input batch of C-channel spectrograms into a batch of
        time-frequency features with dimension num_feat. The output number
        of channels may be the same as input, or reduced to 1, e.g., if
        averaging over magnitude and not appending individual IPDs.

        Args:
            input: Spectrogram for C channels with F subbands and N time frames, (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            num_feat_channels channels with num_feat features, shape (B, num_feat_channels, num_feat, N),
            and the unmodified input_length, shape (B,)
        """
        # Magnitude spectrum, optionally reduced across channels
        if self.mag_reduction is None:
            mag = torch.abs(input)
        elif self.mag_reduction == 'abs_mean':
            mag = torch.abs(torch.mean(input, dim=1, keepdim=True))
        elif self.mag_reduction == 'mean_abs':
            mag = torch.mean(torch.abs(input), dim=1, keepdim=True)
        elif self.mag_reduction == 'rms':
            mag = torch.sqrt(torch.mean(torch.abs(input) ** 2, dim=1, keepdim=True))
        else:
            raise ValueError(f'Unexpected magnitude reduction {self.mag_reduction}')

        if self.mag_normalization is not None:
            mag = self.mag_normalization(mag)

        features = mag

        if self.use_ipd:
            # Calculate IPD relative to average spec
            spec_mean = torch.mean(input, dim=1, keepdim=True)
            ipd = torch.angle(input) - torch.angle(spec_mean)
            # Modulo to [-pi, pi]
            ipd = wrap_to_pi(ipd)

            if self.ipd_normalization is not None:
                ipd = self.ipd_normalization(ipd)

            # Concatenate to existing features along the feature dimension;
            # magnitude may need broadcasting to match the IPD channel count.
            features = torch.cat([features.expand(ipd.shape), ipd], dim=2)

        if self._num_channels is not None and features.size(1) != self._num_channels:
            raise RuntimeError(
                f'Number of channels in features {features.size(1)} is different than the configured number of channels {self._num_channels}'
            )

        return features, input_length
class MaskEstimatorRNN(NeuralModule):
    """Estimate `num_outputs` masks from the input spectrogram
    using stacked RNNs and projections.

    The module is structured as follows:
        input --> spatial features --> input projection -->
        --> stacked RNNs --> output projection for each output --> sigmoid

    Reference:
        Multi-microphone neural speech separation for far-field multi-talker
        speech recognition (https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8462081)

    Args:
        num_outputs: Number of output masks to estimate
        num_subbands: Number of subbands of the input spectrogram
        num_features: Number of features after the input projections
        num_layers: Number of RNN layers
        num_hidden_features: Number of hidden features in RNN layers
        num_input_channels: Number of input channels
        dropout: If non-zero, introduces dropout on the outputs of each RNN layer
                 except the last layer, with dropout probability equal to `dropout`.
                 Default: 0
        bidirectional: If `True`, use bidirectional RNN.
        rnn_type: Type of RNN, either `lstm` or `gru`. Default: `lstm`
        mag_reduction: Channel-wise reduction for magnitude features
        use_ipd: Use inter-channel phase difference (IPD) features
    """

    def __init__(
        self,
        num_outputs: int,
        num_subbands: int,
        num_features: int = 1024,
        num_layers: int = 3,
        num_hidden_features: Optional[int] = None,
        num_input_channels: Optional[int] = None,
        dropout: float = 0,
        bidirectional: bool = True,
        rnn_type: str = 'lstm',
        mag_reduction: str = 'rms',
        use_ipd: bool = False,
    ):
        super().__init__()
        if num_hidden_features is None:
            num_hidden_features = num_features

        # Spatial feature extraction front-end
        self.features = SpectrogramToMultichannelFeatures(
            num_subbands=num_subbands,
            num_input_channels=num_input_channels,
            mag_reduction=mag_reduction,
            use_ipd=use_ipd,
        )

        # Project flattened (channels x features) onto the RNN input size
        self.input_projection = torch.nn.Linear(
            in_features=self.features.num_features * self.features.num_channels, out_features=num_features
        )

        if rnn_type == 'lstm':
            self.rnn = torch.nn.LSTM(
                input_size=num_features,
                hidden_size=num_hidden_features,
                num_layers=num_layers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidirectional,
            )
        elif rnn_type == 'gru':
            self.rnn = torch.nn.GRU(
                input_size=num_features,
                hidden_size=num_hidden_features,
                num_layers=num_layers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidirectional,
            )
        else:
            raise ValueError(f'Unknown rnn_type: {rnn_type}')

        # Each output shares the RNN and has a separate projection
        self.output_projections = torch.nn.ModuleList(
            [
                torch.nn.Linear(
                    in_features=2 * num_features if bidirectional else num_features, out_features=num_subbands
                )
                for _ in range(num_outputs)
            ]
        )
        self.output_nonlinearity = torch.nn.Sigmoid()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate `num_outputs` masks from the input spectrogram.

        Args:
            input: C-channel input, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Returns `num_outputs` masks in a tensor, shape (B, num_outputs, F, N),
            and output length with shape (B,)
        """
        input, _ = self.features(input=input, input_length=input_length)
        B, num_feature_channels, num_features, N = input.shape

        # (B, num_feat_channels, num_feat, N) -> (B, N, num_feat_channels, num_feat)
        input = input.permute(0, 3, 1, 2)
        # (B, N, num_feat_channels, num_feat) -> (B, N, num_feat_channels * num_features)
        input = input.view(B, N, -1)

        # Apply projection on num_feat
        input = self.input_projection(input)

        # Apply RNN on the input sequence
        # NOTE: pack_padded_sequence requires lengths on CPU
        input_packed = torch.nn.utils.rnn.pack_padded_sequence(
            input, input_length.cpu(), batch_first=True, enforce_sorted=False
        ).to(input.device)
        self.rnn.flatten_parameters()
        input_packed, _ = self.rnn(input_packed)
        input, input_length = torch.nn.utils.rnn.pad_packed_sequence(input_packed, batch_first=True)
        input_length = input_length.to(input.device)

        # Create `num_outputs` masks
        output = []
        for output_projection in self.output_projections:
            # Output projection
            mask = output_projection(input)
            mask = self.output_nonlinearity(mask)

            # Back to the original format
            # (B, N, F) -> (B, F, N)
            mask = mask.transpose(2, 1)

            # Append to the output
            output.append(mask)

        # Stack along channel dimension to get (B, M, F, N)
        output = torch.stack(output, dim=1)

        # Mask frames beyond input length
        length_mask: torch.Tensor = make_seq_mask_like(
            lengths=input_length, like=output, time_dim=-1, valid_ones=False
        )
        output = output.masked_fill(length_mask, 0.0)

        return output, input_length
class MaskReferenceChannel(NeuralModule):
    """A simple mask processor which applies mask on ref_channel of the input signal.

    Args:
        ref_channel: Index of the reference channel.
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
    """

    def __init__(self, ref_channel: int = 0, mask_min_db: float = -200, mask_max_db: float = 0):
        super().__init__()
        self.ref_channel = ref_channel
        # Mask thresholding, converted from dB to magnitude
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply mask on `ref_channel` of the input signal.
        This can be used to generate multi-channel output.
        If `mask` has `M` channels, the output will have `M` channels as well.

        Args:
            input: Input signal complex-valued spectrogram, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)
            mask: Mask for M outputs, shape (B, M, F, N)

        Returns:
            M-channel output complex-valued spectrogram with shape (B, M, F, N),
            and the unmodified input_length, shape (B,)
        """
        # Apply thresholds
        mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)

        # Apply each output mask on the ref channel
        output = mask * input[:, self.ref_channel : self.ref_channel + 1, ...]
        return output, input_length
class MaskBasedBeamformer(NeuralModule):
    """Multi-channel processor using masks to estimate signal statistics.

    Args:
        filter_type: string denoting the type of the filter. Defaults to `mvdr`
        ref_channel: reference channel for processing
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
    """

    def __init__(
        self,
        filter_type: str = 'mvdr_souden',
        ref_channel: int = 0,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
    ):
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            # Fixed: this message was previously a plain string, so the
            # {self.__class__.__name__} placeholder was emitted literally.
            raise ModuleNotFoundError(
                f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}"
            )

        super().__init__()
        self.ref_channel = ref_channel
        self.filter_type = filter_type
        if self.filter_type == 'mvdr_souden':
            self.psd = torchaudio.transforms.PSD()
            self.filter = torchaudio.transforms.SoudenMVDR()
        else:
            raise ValueError(f'Unknown filter type {filter_type}')
        # Mask thresholding, converted from dB to magnitude
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a mask-based beamformer to the input spectrogram.
        This can be used to generate multi-channel output.
        If `mask` has `M` channels, the output will have `M` channels as well.

        Args:
            input: Input signal complex-valued spectrogram, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)
            mask: Mask for M output signals, shape (B, M, F, N)

        Returns:
            M-channel output signal complex-valued spectrogram, shape (B, M, F, N),
            and the unmodified input_length, shape (B,)
        """
        # Apply threshold on the mask
        mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)
        # Length mask
        length_mask: torch.Tensor = make_seq_mask_like(
            lengths=input_length, like=mask[:, 0, ...], time_dim=-1, valid_ones=False
        )

        # Use each mask to generate an output at ref_channel
        output = []
        for m in range(mask.size(1)):
            # Prepare mask for the desired and the undesired signal
            mask_desired = mask[:, m, ...].masked_fill(length_mask, 0.0)
            mask_undesired = (1 - mask_desired).masked_fill(length_mask, 0.0)
            # Calculate PSDs
            psd_desired = self.psd(input, mask_desired)
            psd_undesired = self.psd(input, mask_undesired)
            # Apply filter
            output_m = self.filter(input, psd_desired, psd_undesired, reference_channel=self.ref_channel)
            output_m = output_m.masked_fill(length_mask, 0.0)
            # Save the current output (B, F, N)
            output.append(output_m)

        output = torch.stack(output, dim=1)

        return output, input_length
class WPEFilter(NeuralModule):
    """A weighted prediction error filter.
    Given input signal, and expected power of the desired signal, this
    class estimates a multiple-input multiple-output prediction filter
    and returns the filtered signal. Currently, estimation of statistics
    and processing is performed in batch mode.

    Args:
        filter_length: Length of the prediction filter in frames, per channel
        prediction_delay: Prediction delay in frames
        diag_reg: Diagonal regularization for the correlation matrix Q, applied
                  as diag_reg * trace(Q) + eps
        eps: Small positive constant for regularization

    References:
        - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction
          Methods for Blind MIMO Impulse Response Shortening, 2012
        - Jukić et al, Group sparsity for MIMO speech dereverberation, 2015
    """

    def __init__(
        self, filter_length: int, prediction_delay: int, diag_reg: Optional[float] = 1e-8, eps: float = 1e-10
    ):
        super().__init__()
        self.filter_length = filter_length
        self.prediction_delay = prediction_delay
        self.diag_reg = diag_reg
        self.eps = eps

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "power": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, power: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Given input and the predicted power for the desired signal, estimate
        the WPE filter and return the processed signal.

        Args:
            input: Input signal, shape (B, C, F, N)
            power: Predicted power of the desired signal, shape (B, C, F, N)
            input_length: Optional, length of valid frames in `input`. Defaults to `None`

        Returns:
            Tuple of (processed_signal, output_length). Processed signal has the same
            shape as the input signal (B, C, F, N), and the output length is the same
            as the input length.
        """
        # Temporal weighting: average power over channels, shape (B, F, N)
        weight = torch.mean(power, dim=1)
        # Use inverse power as the weight
        weight = 1 / (weight + self.eps)

        # Multi-channel convolution matrix for each subband
        tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay)

        # Estimate correlation matrices
        Q, R = self.estimate_correlations(
            input=input, weight=weight, tilde_input=tilde_input, input_length=input_length
        )

        # Estimate prediction filter
        G = self.estimate_filter(Q=Q, R=R)

        # Apply prediction filter
        undesired_signal = self.apply_filter(filter=G, tilde_input=tilde_input)

        # Dereverberation
        desired_signal = input - undesired_signal

        if input_length is not None:
            # Mask padded frames
            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=input_length, like=desired_signal, time_dim=-1, valid_ones=False
            )
            desired_signal = desired_signal.masked_fill(length_mask, 0.0)

        return desired_signal, input_length

    @classmethod
    def convtensor(
        cls, x: torch.Tensor, filter_length: int, delay: int = 0, n_steps: Optional[int] = None
    ) -> torch.Tensor:
        """Create a tensor equivalent of convmtx_mc for each example in the batch.
        The input signal tensor `x` has shape (B, C, F, N).
        Convtensor returns a view of the input signal `x`.

        Note: We avoid reshaping the output to collapse channels and filter taps into
        a single dimension, e.g., (B, F, N, -1). In this way, the output is a view of
        the input, while an additional reshape would result in a contiguous array and
        more memory use.

        Args:
            x: input tensor, shape (B, C, F, N)
            filter_length: length of the filter, determines the shape of the convolution tensor
            delay: delay to add to the input signal `x` before constructing the convolution tensor
            n_steps: Optional, number of time steps to keep in the out. Defaults to the number
                     of time steps in the input tensor.

        Returns:
            Return a convolutional tensor with shape (B, C, F, n_steps, filter_length)
        """
        if x.ndim != 4:
            raise RuntimeError(f'Expecting a 4-D input. Received input with shape {x.shape}')

        B, C, F, N = x.shape

        if n_steps is None:
            # Keep the same length as the input signal
            n_steps = N

        # Pad temporal dimension so each output frame sees `filter_length`
        # delayed samples
        x = torch.nn.functional.pad(x, (filter_length - 1 + delay, 0))

        # Build Toeplitz-like matrix view by unfolding across time
        tilde_X = x.unfold(-1, filter_length, 1)

        # Trim to the set number of time steps
        tilde_X = tilde_X[:, :, :, :n_steps, :]

        return tilde_X

    @classmethod
    def permute_convtensor(cls, x: torch.Tensor) -> torch.Tensor:
        """Reshape and permute columns to convert the result of
        convtensor to be equal to convmtx_mc. This is used for verification
        purposes and it is not required to use the filter.

        Args:
            x: output of self.convtensor, shape (B, C, F, N, filter_length)

        Returns:
            Output has shape (B, F, N, C*filter_length) that corresponds to
            the layout of convmtx_mc.
        """
        B, C, F, N, filter_length = x.shape

        # .view will not work, so a copy will have to be created with .reshape
        # That will result in more memory use, since we don't use a view of the original
        # multi-channel signal
        x = x.permute(0, 2, 3, 1, 4)
        x = x.reshape(B, F, N, C * filter_length)

        # Reverse the filter taps within each channel block to match convmtx_mc
        permute = []
        for m in range(C):
            permute[m * filter_length : (m + 1) * filter_length] = m * filter_length + np.flip(
                np.arange(filter_length)
            )
        return x[..., permute]

    def estimate_correlations(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        tilde_input: torch.Tensor,
        input_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input: Input signal, shape (B, C, F, N)
            weight: Time-frequency weight, shape (B, F, N)
            tilde_input: Multi-channel convolution tensor, shape (B, C, F, N, filter_length)
            input_length: Length of each input example, shape (B)

        Returns:
            Returns a tuple of correlation matrices for each batch.

            Let `X` denote the input signal in a single subband,
            `tilde{X}` the corresponding multi-channel correlation matrix,
            and `w` the vector of weights.

            The first output is
                Q = tilde{X}^H * diag(w) * tilde{X}  (1)
            for each (b, f).
            The matrix calculated in (1) has shape (C * filter_length, C * filter_length)
            The output is returned in a tensor with shape (B, F, C, filter_length, C, filter_length).

            The second output is
                R = tilde{X}^H * diag(w) * X  (2)
            for each (b, f).
            The matrix calculated in (2) has shape (C * filter_length, C)
            The output is returned in a tensor with shape (B, F, C, filter_length, C). The last
            dimension corresponds to output channels.
        """
        if input_length is not None:
            # Take only valid samples into account
            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=input_length, like=weight, time_dim=-1, valid_ones=False
            )
            weight = weight.masked_fill(length_mask, 0.0)

        # Calculate (1)
        # result: (B, F, C, filter_length, C, filter_length)
        Q = torch.einsum('bjfik,bmfin->bfjkmn', tilde_input.conj(), weight[:, None, :, :, None] * tilde_input)

        # Calculate (2)
        # result: (B, F, C, filter_length, C)
        R = torch.einsum('bjfik,bmfi->bfjkm', tilde_input.conj(), weight[:, None, :, :] * input)

        return Q, R

    def estimate_filter(self, Q: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
        """Estimate the MIMO prediction filter as
            G(b,f) = Q(b,f) \\ R(b,f)
        for each subband in each example in the batch (b, f).

        Args:
            Q: shape (B, F, C, filter_length, C, filter_length)
            R: shape (B, F, C, filter_length, C)

        Returns:
            Complex-valued prediction filter, shape (B, C, F, C, filter_length)
        """
        B, F, C, filter_length, _, _ = Q.shape
        assert (
            filter_length == self.filter_length
        ), f'Shape of Q {Q.shape} is not matching filter length {self.filter_length}'

        # Reshape to analytical dimensions for each (b, f)
        Q = Q.reshape(B, F, C * filter_length, C * filter_length)
        R = R.reshape(B, F, C * filter_length, C)

        # Diagonal regularization
        if self.diag_reg:
            # Regularization: diag_reg * trace(Q) + eps
            diag_reg = self.diag_reg * torch.diagonal(Q, dim1=-2, dim2=-1).sum(-1).real + self.eps
            # Apply regularization on Q
            Q = Q + torch.diag_embed(diag_reg.unsqueeze(-1) * torch.ones(Q.shape[-1], device=Q.device))

        # Solve for the filter
        G = torch.linalg.solve(Q, R)

        # Reshape to desired representation: (B, F, input channels, filter_length, output channels)
        G = G.reshape(B, F, C, filter_length, C)
        # Move output channels to front: (B, output channels, F, input channels, filter_length)
        G = G.permute(0, 4, 1, 2, 3)

        return G

    def apply_filter(
        self, filter: torch.Tensor, input: Optional[torch.Tensor] = None, tilde_input: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply a prediction filter `filter` on the input `input` as
            output(b,f) = tilde{input(b,f)} * filter(b,f)
        If available, directly use the convolution matrix `tilde_input`.

        Args:
            input: Input signal, shape (B, C, F, N)
            tilde_input: Convolution matrix for the input signal, shape (B, C, F, N, filter_length)
            filter: Prediction filter, shape (B, C, F, C, filter_length)

        Returns:
            Multi-channel signal obtained by applying the prediction filter on
            the input signal, same shape as input (B, C, F, N)
        """
        # Exactly one of `input` and `tilde_input` must be provided
        if input is None and tilde_input is None:
            raise RuntimeError('Both inputs cannot be None simultaneously.')
        if input is not None and tilde_input is not None:
            raise RuntimeError('Both inputs cannot be provided simultaneously.')

        if tilde_input is None:
            tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay)

        # For each (batch, output channel, f, time step), sum across (input channel, filter tap)
        output = torch.einsum('bjfik,bmfjk->bmfi', tilde_input, filter)

        return output
class MaskBasedDereverbWPE(NeuralModule):
    """Multi-channel linear prediction-based dereverberation using
    weighted prediction error for filter estimation.
    An optional mask to estimate the signal power can be provided.
    If a time-frequency mask is not provided, the algorithm corresponds
    to the conventional WPE algorithm.

    Args:
        filter_length: Length of the convolutional filter for each channel in frames.
        prediction_delay: Delay of the input signal for multi-channel linear prediction in frames.
        num_iterations: Number of iterations for reweighting
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
        diag_reg: Diagonal regularization for WPE
        eps: Small regularization constant

    References:
        - Kinoshita et al, Neural network-based spectrum estimation for online WPE dereverberation, 2017
        - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction Methods for Blind
          MIMO Impulse Response Shortening, 2012
    """

    def __init__(
        self,
        filter_length: int,
        prediction_delay: int,
        num_iterations: int = 1,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
        diag_reg: Optional[float] = 1e-8,
        eps: float = 1e-10,
    ):
        super().__init__()
        # At least one iteration is required, otherwise `forward` would have
        # no filter output to return.
        if num_iterations < 1:
            raise ValueError(f'Expected num_iterations >= 1, got {num_iterations}')
        # Filter setup
        self.filter = WPEFilter(
            filter_length=filter_length, prediction_delay=prediction_delay, diag_reg=diag_reg, eps=eps
        )
        self.num_iterations = num_iterations
        # Mask thresholding, converted from dB to magnitude
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Given an input signal `input`, apply the WPE dereverberation algorithm.

        Args:
            input: C-channel complex-valued spectrogram, shape (B, C, F, N)
            input_length: Optional length for each signal in the batch, shape (B,)
            mask: Optional mask, shape (B, 1, F, N) or (B, C, F, N)

        Returns:
            Processed tensor with the same number of channels as the input,
            shape (B, C, F, N), and the output length, shape (B,)
        """
        # Remember the original dtype so the result can be cast back after
        # double-precision processing.
        io_dtype = input.dtype

        # WPE is numerically sensitive, so run in double precision with
        # autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            output = input.cdouble()

            for i in range(self.num_iterations):
                magnitude = torch.abs(output)
                if i == 0 and mask is not None:
                    # Apply thresholds
                    mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)
                    # Mask magnitude
                    magnitude = mask * magnitude
                # Calculate power
                power = magnitude ** 2
                # Apply filter
                output, output_length = self.filter(input=output, input_length=input_length, power=power)

        return output.to(io_dtype), output_length