Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from asteroid.models.base_models import ( | |
| BaseEncoderMaskerDecoder, | |
| _unsqueeze_to_3d, | |
| _shape_reconstructed, | |
| ) | |
| from asteroid.utils.torch_utils import pad_x_to_y, jitable_shape | |
| from einops import rearrange | |
class BaseEncoderMaskerDecoderWithConfigs(BaseEncoderMaskerDecoder):
    """Encoder/Masker/Decoder model whose three stages can be toggled.

    Some models we use, like TFC-TDF-UNet, have no masker; UMX or X-UMX
    already apply masking inside their own implementation. Rather than
    modifying those model codes, this wrapper lets each stage be switched
    off via keyword arguments.

    Keyword Args:
        use_encoder (bool): run the encoder on the input (default: True).
        apply_mask (bool): apply the estimated masks to the TF representation;
            set False when the model masks internally (default: True).
        use_decoder (bool): run the decoder on the masked representation
            (default: True).
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        # Every stage defaults to enabled; callers opt out per model.
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)

    def forward(self, wav):
        """Enc/Mask/Dec forward with some additional options.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time),
            or a (masked_tf_rep, reconstruction) pair when the decoder is used.
        """
        # Remember the input shape for reconstruction; cast for torchscript.
        shape = jitable_shape(wav)
        wav = _unsqueeze_to_3d(wav)  # -> (batch, n_mix, time)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav
        est_masks = self.forward_masker(tf_rep)

        # When the model already masked internally, the masker output *is*
        # the masked representation.
        masked_tf_rep = (
            self.apply_masks(tf_rep, est_masks) if self.apply_mask else est_masks
        )

        if not self.use_decoder:
            # In UMX or X-UMX, the decoder is not used.
            return masked_tf_rep

        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = pad_x_to_y(decoded, wav)
        return masked_tf_rep, _shape_reconstructed(reconstructed, shape)
class BaseEncoderMaskerDecoder_mixture_consistency(BaseEncoderMaskerDecoder):
    """Enc/Mask/Dec model whose estimates are projected to be mixture consistent."""

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__(encoder, masker, decoder, encoder_activation)

    def forward(self, wav):
        """Enc/Mask/Dec model forward with mixture consistent output.

        References:
            [1] : Wisdom, Scott, et al. "Differentiable consistency constraints for improved deep speech enhancement." ICASSP 2019.
            [2] : Wisdom, Scott, et al. "Unsupervised sound separation using mixture invariant training." NeurIPS 2020.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Remember the input shape for reconstruction; cast for torchscript.
        shape = jitable_shape(wav)
        wav = _unsqueeze_to_3d(wav)  # -> (batch, n_mix, time)

        # Standard encode -> mask -> decode pipeline.
        tf_rep = self.forward_encoder(wav)
        masked_tf_rep = self.apply_masks(tf_rep, self.forward_masker(tf_rep))
        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = _shape_reconstructed(pad_x_to_y(decoded, wav), shape)

        # Mixture-consistency projection [1]: spread the residual between the
        # mixture and the sum of estimates evenly across the n_src sources.
        n_src = reconstructed.shape[1]
        residual = wav - reconstructed.sum(dim=1, keepdim=True)
        return reconstructed + residual / n_src
class BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(BaseEncoderMaskerDecoder):
    """Enc/Mask/Dec wrapper for the De-limit task that applies the mask on
    the decoder *output* (time domain) instead of on the TF representation,
    so the decoder learns a sample-wise ratio of the sources.

    Keyword Args:
        use_encoder (bool): run the encoder on the input (default: True).
        apply_mask (bool): kept for interface parity with the sibling
            wrappers; not consulted in this class's forward.
        use_decoder (bool): run the decoder on the estimated mask
            (default: True).
        nb_channels (int): number of input channels; 2 (stereo) triggers
            channel-into-frequency flattening before the masker (default: 2).
        decoder_activation (str): one of "sigmoid", "relu", "relu6", "tanh",
            "none". Unknown values fall back to sigmoid (default: "sigmoid").
    """

    # Name -> activation module class; unknown names fall back to sigmoid
    # in __init__, matching the historical if/elif chain.
    _ACTIVATIONS = {
        "sigmoid": nn.Sigmoid,
        "relu": nn.ReLU,
        "relu6": nn.ReLU6,
        "tanh": nn.Tanh,
        "none": nn.Identity,
    }

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "sigmoid")
        act_cls = self._ACTIVATIONS.get(self.decoder_activation, nn.Sigmoid)
        self.act_after_dec = act_cls()

    def forward(self, wav):
        """For the De-limit task, apply the mask on the output of the decoder.

        We want the decoder to learn the sample-wise ratio of the sources.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            tuple: ``(est_masks_decoded, decoded)`` when the decoder is used,
            else a 1-tuple with the raw masker output.
        """
        # NOTE: unlike the sibling wrappers, the original input shape is never
        # restored here, so the dead `jitable_shape(wav)` call was removed.
        wav = _unsqueeze_to_3d(wav)  # (batch, n_channels, time)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav
        if self.nb_channels == 2:
            # Fold stereo channels into the frequency axis for the masker.
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")
        est_masks = self.forward_masker(tf_rep)  # (batch, 1, freq, time)

        if not self.use_decoder:
            return (est_masks,)

        # Decode the mask itself, then apply it in the time domain.
        if self.nb_channels == 2:
            est_masks = rearrange(est_masks, "b 1 f t -> b f t")
        est_masks_decoded = self.forward_decoder(est_masks)
        est_masks_decoded = pad_x_to_y(est_masks_decoded, wav)  # (batch, 1, time)
        est_masks_decoded = self.act_after_dec(est_masks_decoded)
        decoded = wav * est_masks_decoded  # (batch, n_channels, time)
        return est_masks_decoded, decoded
class BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(BaseEncoderMaskerDecoder):
    """Enc/Mask/Dec model with stage toggles, for multi-channel usage of
    asteroid-based models (e.g. ConvTasNet): stereo inputs are flattened
    channel-into-frequency before the masker and unflattened afterwards.

    Keyword Args:
        use_encoder (bool): run the encoder on the input (default: True).
        apply_mask (bool): multiply the estimated masks onto the TF
            representation (default: True).
        use_decoder (bool): run the decoder (default: True).
        nb_channels (int): number of input channels; 2 means stereo (default: 2).
        decoder_activation (str): "sigmoid", "relu", "relu6", "tanh" or
            "none"; unknown values fall back to sigmoid (default: "none").
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "none")
        name_to_act = {
            "sigmoid": nn.Sigmoid,
            "relu": nn.ReLU,
            "relu6": nn.ReLU6,
            "tanh": nn.Tanh,
            "none": nn.Identity,
        }
        # Anything unrecognized falls back to sigmoid.
        self.act_after_dec = name_to_act.get(self.decoder_activation, nn.Sigmoid)()

    def forward(self, wav):
        """Enc/Mask/Dec model forward with some additional options.

        For MultiChannel usage of asteroid-based models (e.g. ConvTasNet).

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time),
            or a (masked_tf_rep, reconstruction) pair when the decoder is used.
        """
        # Remember the input shape for reconstruction; cast for torchscript.
        shape = jitable_shape(wav)
        wav = _unsqueeze_to_3d(wav)  # -> (batch, n_mix, time)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav

        stereo = self.nb_channels == 2
        if stereo:
            # Fold the stereo channels into the frequency axis for the masker.
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")
        est_masks = self.forward_masker(tf_rep)
        if stereo:
            tf_rep = rearrange(tf_rep, "b (c f) t -> b c f t", c=self.nb_channels)

        # Plain elementwise product: the original asteroid implementation of
        # masking includes an unnecessary unsqueeze, so mask manually here.
        masked_tf_rep = est_masks * tf_rep if self.apply_mask else est_masks

        if not self.use_decoder:
            return masked_tf_rep

        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = self.act_after_dec(pad_x_to_y(decoded, wav))
        return masked_tf_rep, _shape_reconstructed(reconstructed, shape)