# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from typing import Optional

import numpy as np
import pytest
import torch

from nemo.collections.asr.modules.audio_modules import (
    MaskBasedDereverbWPE,
    MaskReferenceChannel,
    SpectrogramToMultichannelFeatures,
    WPEFilter,
)
from nemo.collections.asr.modules.audio_preprocessing import AudioToSpectrogram
from nemo.collections.asr.parts.utils.audio_utils import convmtx_mc_numpy

# Tests that build spectrograms with AudioToSpectrogram are skipped when torchaudio is unavailable
try:
    importlib.import_module('torchaudio')

    HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
    HAVE_TORCHAUDIO = False


class TestSpectrogramToMultichannelFeatures:
    """Tests for the magnitude and IPD features produced by SpectrogramToMultichannelFeatures."""

    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_TORCHAUDIO, reason="Modules in this test require torchaudio")
    @pytest.mark.parametrize('fft_length', [256])
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('mag_reduction', [None, 'rms', 'abs_mean', 'mean_abs'])
    def test_magnitude(self, fft_length: int, num_channels: int, mag_reduction: Optional[str]):
        """Test magnitude features for multi-channel audio.

        With IPD disabled and no normalization, the features must equal the
        magnitude of the input spectrogram. When `mag_reduction` is set, the
        magnitude is reduced across the channel axis (axis 1), keeping that axis.
        """
        atol = 1e-6
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 25
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands, mag_reduction=mag_reduction, use_ipd=False, mag_normalization=None,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # UUT (unit under test) output
            feat, _ = spec2feat(input=spec, input_length=spec_len)
            feat_np = feat.cpu().detach().numpy()

            # Golden output, computed directly from the spectrogram with numpy
            spec_np = spec.cpu().detach().numpy()
            if mag_reduction is None:
                feat_golden = np.abs(spec_np)
            elif mag_reduction == 'rms':
                feat_golden = np.sqrt(np.mean(np.abs(spec_np) ** 2, axis=1, keepdims=True))
            elif mag_reduction == 'mean_abs':
                feat_golden = np.mean(np.abs(spec_np), axis=1, keepdims=True)
            elif mag_reduction == 'abs_mean':
                feat_golden = np.abs(np.mean(spec_np, axis=1, keepdims=True))
            else:
                raise NotImplementedError()

            # Compare shape
            assert feat_np.shape == feat_golden.shape, f'Feature shape not matching for example {n}'

            # Compare values
            assert np.allclose(feat_np, feat_golden, atol=atol), f'Features not matching for example {n}'

    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_TORCHAUDIO, reason="Modules in this test require torchaudio")
    @pytest.mark.parametrize('fft_length', [256])
    @pytest.mark.parametrize('num_channels', [1, 4])
    def test_ipd(self, fft_length: int, num_channels: int):
        """Test IPD (inter-channel phase difference) features for multi-channel audio.

        The features after the first `num_subbands` rows (axis -2) must equal
        the phase of each channel relative to the phase of the channel-mean
        spectrogram, wrapped to [-pi, pi).
        """
        atol = 1e-5
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands,
            mag_reduction='rms',
            use_ipd=True,
            mag_normalization=None,
            ipd_normalization=None,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # UUT output
            feat, _ = spec2feat(input=spec, input_length=spec_len)
            feat_np = feat.cpu().detach().numpy()
            # The first `num_subbands` rows hold magnitude features; the remainder is the IPD
            ipd = feat_np[..., audio2spec.num_subbands :, :]

            # Golden output: phase relative to the channel-mean, wrapped to [-pi, pi)
            spec_np = spec.cpu().detach().numpy()
            spec_mean = np.mean(spec_np, axis=1, keepdims=True)
            ipd_golden = np.angle(spec_np) - np.angle(spec_mean)
            ipd_golden = np.remainder(ipd_golden + np.pi, 2 * np.pi) - np.pi

            # Compare shape
            assert ipd.shape == ipd_golden.shape, f'Feature shape not matching for example {n}'

            # Compare values
            assert np.allclose(ipd, ipd_golden, atol=atol), f'Features not matching for example {n}'


class TestMaskBasedProcessor:
    """Tests for applying masks to a multi-channel spectrogram."""

    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_TORCHAUDIO, reason="Modules in this test require torchaudio")
    @pytest.mark.parametrize('fft_length', [256])
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('num_masks', [1, 2])
    def test_mask_reference_channel(self, fft_length: int, num_channels: int, num_masks: int):
        """Test masking of the reference channel.

        For each mask `m`, the output must be the reference channel of the
        input spectrogram multiplied by that mask:
        `out[:, m] = spec[:, ref_channel] * mask[:, m]`.
        """
        if num_channels == 1:
            # Only one channel available
            ref_channels = [0]
        else:
            # Use first or last channel for MC signals
            ref_channels = [0, num_channels - 1]

        atol = 1e-6
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        for ref_channel in ref_channels:

            mask_processor = MaskReferenceChannel(ref_channel=ref_channel)

            for n in range(num_examples):
                x = _rng.normal(size=(batch_size, num_channels, num_samples))

                spec, spec_len = audio2spec(
                    input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size)
                )

                # Randomly-generated mask
                mask = _rng.uniform(
                    low=0.0, high=1.0, size=(batch_size, num_masks, audio2spec.num_subbands, spec.shape[-1])
                )

                # UUT output
                out, _ = mask_processor(input=spec, input_length=spec_len, mask=torch.tensor(mask))
                out_np = out.cpu().detach().numpy()

                # Golden output
                spec_np = spec.cpu().detach().numpy()
                out_golden = np.zeros_like(mask, dtype=spec_np.dtype)
                for m in range(num_masks):
                    out_golden[:, m, ...] = spec_np[:, ref_channel, ...] * mask[:, m, ...]

                # Compare shape
                assert out_np.shape == out_golden.shape, f'Output shape not matching for example {n}'

                # Compare values
                assert np.allclose(out_np, out_golden, atol=atol), f'Output not matching for example {n}'


class TestMaskBasedDereverb:
    """Tests for the WPE filter components and mask-based WPE dereverberation."""

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_convtensor(self, num_channels: int, filter_length: int, delay: int):
        """Test construction of convolutional tensor in WPE.

        Compare `WPEFilter.convtensor` (reorganized with `WPEFilter.permute_convtensor`)
        against the reference implementation `convmtx_mc_numpy`, which is applied
        separately for every batch element and subband.
        """
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21

        _rng = np.random.default_rng(seed=random_seed)
        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            X = _rng.normal(size=input_size) + 1j * _rng.normal(size=input_size)

            # Reference, shape (B, F, N, C * filter_length)
            tilde_X_ref = np.zeros((batch_size, num_subbands, num_frames, num_channels * filter_length), dtype=X.dtype)
            for b in range(batch_size):
                for f in range(num_subbands):
                    # convmtx_mc_numpy receives the (N, C) slice for this batch element and subband
                    tilde_X_ref[b, f, :, :] = convmtx_mc_numpy(
                        X[b, :, f, :].transpose(), filter_length=filter_length, delay=delay
                    )

            # UUT
            tilde_X_uut = WPEFilter.convtensor(torch.tensor(X), filter_length=filter_length, delay=delay)

            # UUT has vectors arranged in a tensor shape with permuted columns
            # Reorganize to match the shape and column permutation
            tilde_X_uut = WPEFilter.permute_convtensor(tilde_X_uut)
            tilde_X_uut = tilde_X_uut.cpu().detach().numpy()

            assert np.allclose(tilde_X_uut, tilde_X_ref, atol=atol), f'Example {n}: comparison failed'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_filter(self, num_channels: int, filter_length: int, delay: int):
        """Test estimation of correlation matrices, filter and filtering.

        Each WPEFilter step (estimate_correlations, estimate_filter, apply_filter)
        is compared against a reference computed with plain torch matrix
        operations on the flattened convolutional tensor.
        """
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 4
        num_subbands = 15
        num_frames = 50

        wpe_filter = WPEFilter(filter_length=filter_length, prediction_delay=delay, diag_reg=None)

        _rng = np.random.default_rng(seed=random_seed)
        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            X = torch.tensor(_rng.normal(size=input_size) + 1j * _rng.normal(size=input_size))
            weight = torch.tensor(_rng.uniform(size=(batch_size, num_subbands, num_frames)))

            # Create convtensor (B, C, F, N, filter_length)
            tilde_X = wpe_filter.convtensor(X, filter_length=filter_length, delay=delay)

            # Test 1:
            # estimate_correlation

            # Reference
            # move channels to back
            X_golden = X.permute(0, 2, 3, 1)
            # move channels to back and reshape to (B, F, N, C*filter_length)
            tilde_X_golden = tilde_X.permute(0, 2, 3, 1, 4).reshape(
                batch_size, num_subbands, num_frames, num_channels * filter_length
            )
            # (B, F, C * filter_length, C * filter_length)
            Q_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * tilde_X_golden)
            # (B, F, C * filter_length, C)
            R_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * X_golden)

            # UUT
            Q_uut, R_uut = wpe_filter.estimate_correlations(input=X, weight=weight, tilde_input=tilde_X)
            # Flatten (B, F, C, filter_length, C, filter_length) into (B, F, C*filter_length, C*filter_length)
            Q_uut_flattened = Q_uut.flatten(start_dim=-2, end_dim=-1).flatten(start_dim=-3, end_dim=-2)
            # Flatten (B, F, C, filter_length, C) into (B, F, C*filter_length, C)
            R_uut_flattened = R_uut.flatten(start_dim=-3, end_dim=-2)

            assert torch.allclose(Q_uut_flattened, Q_golden, atol=atol), f'Example {n}: comparison failed for Q'
            assert torch.allclose(R_uut_flattened, R_golden, atol=atol), f'Example {n}: comparison failed for R'

            # Test 2:
            # estimate_filter

            # Reference
            G_golden = torch.linalg.solve(Q_golden, R_golden)

            # UUT
            G_uut = wpe_filter.estimate_filter(Q_uut, R_uut)
            # Flatten and move output channels to back
            G_uut_flattened = G_uut.reshape(batch_size, num_channels, num_subbands, -1).permute(0, 2, 3, 1)

            assert torch.allclose(G_uut_flattened, G_golden, atol=atol), f'Example {n}: comparison failed for G'

            # Test 3:
            # apply_filter

            # Reference
            U_golden = torch.matmul(tilde_X_golden, G_golden)

            # UUT
            U_uut = wpe_filter.apply_filter(filter=G_uut, tilde_input=tilde_X)
            # Move channels to back to match the reference layout
            U_uut_ref = U_uut.permute(0, 2, 3, 1)

            assert torch.allclose(
                U_uut_ref, U_golden, atol=atol
            ), f'Example {n}: comparison failed for undesired output U'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [3])
    @pytest.mark.parametrize('filter_length', [5])
    @pytest.mark.parametrize('delay', [0, 2])
    def test_mask_based_dereverb_init(self, num_channels: int, filter_length: int, delay: int):
        """Test that dereverb can be initialized and can process audio.

        Only the output shape and the returned lengths are checked; the
        output values are not compared against a reference.
        """
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21
        num_iterations = 2
        random_seed = 42

        # Local seeded generator: keeps inputs reproducible (like the seeded numpy
        # generators in the other tests) without touching the global torch RNG state
        gen = torch.Generator().manual_seed(random_seed)

        # NOTE(review): this layout is (B, F, N, C), unlike the (B, C, F, N) layout
        # used by the WPE tests above. Only shapes/lengths are asserted below, so
        # confirm which layout MaskBasedDereverbWPE actually expects.
        input_size = (batch_size, num_subbands, num_frames, num_channels)

        dereverb = MaskBasedDereverbWPE(
            filter_length=filter_length, prediction_delay=delay, num_iterations=num_iterations
        )

        for n in range(num_examples):
            # multi-channel input
            x = torch.randn(input_size, generator=gen) + 1j * torch.randn(input_size, generator=gen)

            # random input_length in [1, num_frames)
            x_length = torch.randint(1, num_frames, (batch_size,), generator=gen)

            # multi-channel mask
            mask = torch.rand(input_size, generator=gen)

            # UUT
            y, y_length = dereverb(input=x, input_length=x_length, mask=mask)

            assert y.shape == x.shape, f'Output shape not matching, example {n}'
            assert torch.equal(y_length, x_length), f'Length not matching, example {n}'