| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import importlib |
| from typing import Optional |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
| from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures |
| from nemo.collections.audio.modules.masking import ( |
| MaskBasedDereverbWPE, |
| MaskEstimatorFlexChannels, |
| MaskEstimatorGSS, |
| MaskReferenceChannel, |
| ) |
| from nemo.collections.audio.modules.ssl_pretrain_masking import SSLPretrainWithMaskedPatch |
| from nemo.collections.audio.modules.transforms import AudioToSpectrogram |
| from nemo.collections.audio.parts.submodules.multichannel import WPEFilter |
| from nemo.collections.audio.parts.utils.audio import convmtx_mc_numpy |
| from nemo.utils import logging |
|
|
|
|
class TestSpectrogramToMultichannelFeatures:
    """Unit tests for ``SpectrogramToMultichannelFeatures``.

    Covers magnitude features, IPD features, and the ``num_channels`` /
    ``num_features`` properties, comparing against NumPy golden references.
    """

    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [128])
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('mag_reduction', [None, 'rms', 'abs_mean', 'mean_abs'])
    @pytest.mark.parametrize('mag_power', [None, 2])
    @pytest.mark.parametrize('mag_normalization', [None, 'mean', 'mean_var'])
    def test_magnitude(
        self,
        fft_length: int,
        num_channels: int,
        mag_reduction: Optional[str],
        mag_power: Optional[float],
        mag_normalization: Optional[str],
    ):
        """Test calculation of spatial features for multi-channel audio."""
        atol = 5e-5
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        # Unit under test: magnitude-only features (IPD disabled)
        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands,
            mag_reduction=mag_reduction,
            mag_power=mag_power,
            mag_normalization=mag_normalization,
            use_ipd=False,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            # Transform audio to a complex spectrogram
            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # Run the feature extractor under test
            feat, _ = spec2feat(input=spec, input_length=spec_len)
            feat_np = feat.cpu().detach().numpy()

            # Golden reference computed with NumPy, channel axis is axis=1
            spec_np = spec.cpu().detach().numpy()
            if mag_reduction is None:
                feat_golden = np.abs(spec_np)
            elif mag_reduction == 'rms':
                feat_golden = np.sqrt(np.mean(np.abs(spec_np) ** 2, axis=1, keepdims=True))
            elif mag_reduction == 'mean_abs':
                feat_golden = np.mean(np.abs(spec_np), axis=1, keepdims=True)
            elif mag_reduction == 'abs_mean':
                feat_golden = np.abs(np.mean(spec_np, axis=1, keepdims=True))
            else:
                raise NotImplementedError(f'Magnitude reduction {mag_reduction} not implemented')

            if mag_power is not None:
                feat_golden = np.power(feat_golden, mag_power)

            # Normalization statistics are computed over channel (1) and time (3) axes
            if mag_normalization == 'mean':
                feat_golden = feat_golden - np.mean(feat_golden, axis=(1, 3), keepdims=True)
            elif mag_normalization == 'mean_var':
                feat_golden = feat_golden - np.mean(feat_golden, axis=(1, 3), keepdims=True)
                feat_golden = feat_golden / np.sqrt(np.mean(feat_golden**2, axis=(1, 3), keepdims=True))

            # Compare shape
            assert feat_np.shape == feat_golden.shape, f'Feature shape not matching for example {n}'

            # Compare values
            assert np.allclose(feat_np, feat_golden, atol=atol), f'Features not matching for example {n}'

    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [128])
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('ipd_normalization', [None, 'mean', 'mean_var'])
    @pytest.mark.parametrize('use_input_length', [True, False])
    def test_ipd(self, fft_length: int, num_channels: int, ipd_normalization: Optional[str], use_input_length: bool):
        """Test calculation of IPD spatial features for multi-channel audio."""
        atol = 5e-5
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        # Unit under test: RMS magnitude + IPD features
        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands,
            mag_reduction='rms',
            use_ipd=True,
            mag_normalization=None,
            ipd_normalization=ipd_normalization,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # Run the feature extractor, optionally without input_length
            feat, _ = spec2feat(input=spec, input_length=spec_len if use_input_length else None)
            feat_np = feat.cpu().detach().numpy()
            # IPD features are concatenated after the magnitude subbands
            ipd = feat_np[..., audio2spec.num_subbands :, :]

            # Golden reference: phase difference to the mean channel, wrapped to [-pi, pi)
            spec_np = spec.cpu().detach().numpy()
            spec_mean = np.mean(spec_np, axis=1, keepdims=True)
            ipd_golden = np.angle(spec_np) - np.angle(spec_mean)
            ipd_golden = np.remainder(ipd_golden + np.pi, 2 * np.pi) - np.pi

            # Normalization statistics over channel (1) and time (3) axes
            if ipd_normalization == 'mean':
                ipd_golden = ipd_golden - np.mean(ipd_golden, axis=(1, 3), keepdims=True)
            elif ipd_normalization == 'mean_var':
                ipd_golden = ipd_golden - np.mean(ipd_golden, axis=(1, 3), keepdims=True)
                # eps guards against division by zero, matching the module's behavior
                ipd_golden = ipd_golden / np.sqrt(
                    np.maximum(np.mean(ipd_golden**2, axis=(1, 3), keepdims=True), spec2feat.eps)
                )

            # Compare shape
            assert ipd.shape == ipd_golden.shape, f'Feature shape not matching for example {n}'

            # Compare values
            assert np.allclose(ipd, ipd_golden, atol=atol), f'Features not matching for example {n}'

    @pytest.mark.unit
    @pytest.mark.parametrize('use_ipd', [False, True])
    def test_num_channels(self, use_ipd: bool):
        """Test num channels property."""
        uut = SpectrogramToMultichannelFeatures(num_subbands=32, use_ipd=use_ipd)
        with pytest.raises(ValueError):
            # Number of channels is unknown until configured or inferred
            uut.num_channels

        for num_channels in [1, 2, 3, 4]:
            # Without magnitude reduction, all input channels are preserved
            uut = SpectrogramToMultichannelFeatures(num_subbands=32, num_input_channels=num_channels, use_ipd=use_ipd)
            assert uut.num_channels == num_channels

        for num_channels in [1, 2, 3, 4]:
            # With magnitude reduction, a single channel remains unless IPD re-expands it
            uut = SpectrogramToMultichannelFeatures(
                num_subbands=32, num_input_channels=num_channels, use_ipd=use_ipd, mag_reduction='rms'
            )
            if use_ipd:
                assert uut.num_channels == num_channels
            else:
                assert uut.num_channels == 1

    @pytest.mark.unit
    @pytest.mark.parametrize('use_ipd', [False, True])
    def test_num_features(self, use_ipd: bool):
        """Test num features property."""
        for num_subbands in [5, 10]:
            uut = SpectrogramToMultichannelFeatures(num_subbands=num_subbands, use_ipd=use_ipd)
            # NOTE: the conditional must be evaluated first. The previous form
            # `assert uut.num_features == 2 * num_subbands if use_ipd else num_subbands`
            # parsed as `(num_features == 2 * num_subbands) if use_ipd else num_subbands`,
            # so for use_ipd=False it asserted a truthy integer and always passed.
            expected_num_features = 2 * num_subbands if use_ipd else num_subbands
            assert uut.num_features == expected_num_features

    @pytest.mark.unit
    def test_unsupported_norm(self):
        """Test initialization with unsupported normalization."""
        # Unsupported magnitude normalization is rejected at construction time
        with pytest.raises(NotImplementedError):
            SpectrogramToMultichannelFeatures(
                num_subbands=32,
                mag_reduction='rms',
                use_ipd=False,
                mag_normalization='not-implemented',
            )
        # Unsupported IPD normalization is rejected at construction time
        with pytest.raises(NotImplementedError):
            SpectrogramToMultichannelFeatures(
                num_subbands=32,
                use_ipd=True,
                ipd_normalization='not-implemented',
            )
        # Unsupported magnitude reduction is only rejected when processing input
        uut = SpectrogramToMultichannelFeatures(
            num_subbands=32,
            mag_reduction='not-implemented',
        )
        spec_input = torch.randn(1, 3, 100, 100)
        with pytest.raises(ValueError):
            uut(input=spec_input, input_length=torch.Tensor([100]))
|
|
|
|
class TestMaskBasedProcessor:
    """Unit tests for mask-based processing of multi-channel spectrograms."""

    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [256])
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('num_masks', [1, 2])
    def test_mask_reference_channel(self, fft_length: int, num_channels: int, num_masks: int):
        """Test masking of the reference channel."""
        # Single channel: only the first channel can be the reference.
        # Multi-channel: check the first and the last channel.
        ref_channels = [0] if num_channels == 1 else [0, num_channels - 1]

        atol = 1e-6
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        for ref_channel in ref_channels:

            processor = MaskReferenceChannel(ref_channel=ref_channel)

            for n in range(num_examples):
                x = rng.normal(size=(batch_size, num_channels, num_samples))

                spec, spec_len = audio2spec(
                    input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size)
                )

                # One random mask per output in [0, 1)
                mask = rng.uniform(
                    low=0.0, high=1.0, size=(batch_size, num_masks, audio2spec.num_subbands, spec.shape[-1])
                )

                # Run the processor under test
                output, _ = processor(input=spec, input_length=spec_len, mask=torch.tensor(mask))
                output_np = output.cpu().detach().numpy()

                # Golden reference: each mask applied to the reference channel
                spec_np = spec.cpu().detach().numpy()
                expected = np.zeros_like(mask, dtype=spec_np.dtype)
                for m in range(num_masks):
                    expected[:, m, ...] = spec_np[:, ref_channel, ...] * mask[:, m, ...]

                # Compare shape
                assert output_np.shape == expected.shape, f'Output shape not matching for example {n}'

                # Compare values
                assert np.allclose(output_np, expected, atol=atol), f'Output not matching for example {n}'
|
|
|
|
class TestMaskBasedDereverb:
    """Unit tests for WPE-based dereverberation components."""

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_convtensor(self, num_channels: int, filter_length: int, delay: int):
        """Test construction of convolutional tensor in WPE. Compare against
        reference implementation convmtx_mc.
        """
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21

        _rng = np.random.default_rng(seed=random_seed)
        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            # Random complex-valued input
            X = _rng.normal(size=input_size) + 1j * _rng.normal(size=input_size)

            # Reference: convmtx_mc_numpy applied per batch and subband
            tilde_X_ref = np.zeros((batch_size, num_subbands, num_frames, num_channels * filter_length), dtype=X.dtype)
            for b in range(batch_size):
                for f in range(num_subbands):
                    tilde_X_ref[b, f, :, :] = convmtx_mc_numpy(
                        X[b, :, f, :].transpose(), filter_length=filter_length, delay=delay
                    )

            # Unit under test
            tilde_X_uut = WPEFilter.convtensor(torch.tensor(X), filter_length=filter_length, delay=delay)

            # Permute UUT output to match the reference layout before comparing
            tilde_X_uut = WPEFilter.permute_convtensor(tilde_X_uut)
            tilde_X_uut = tilde_X_uut.cpu().detach().numpy()

            assert np.allclose(tilde_X_uut, tilde_X_ref, atol=atol), f'Example {n}: comparison failed'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_filter(self, num_channels: int, filter_length: int, delay: int):
        """Test estimation of correlation matrices, filter and filtering."""
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 4
        num_subbands = 15
        num_frames = 50

        # diag_reg disabled so the exact linalg.solve reference below matches
        wpe_filter = WPEFilter(filter_length=filter_length, prediction_delay=delay, diag_reg=None)

        _rng = np.random.default_rng(seed=random_seed)
        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            X = torch.tensor(_rng.normal(size=input_size) + 1j * _rng.normal(size=input_size))
            weight = torch.tensor(_rng.uniform(size=(batch_size, num_subbands, num_frames)))

            # Convolutional tensor of delayed input frames
            tilde_X = wpe_filter.convtensor(X, filter_length=filter_length, delay=delay)

            # Test 1: weighted correlation matrices Q and R.
            # Golden reference works in (batch, subband, frame, ...) layout.
            X_golden = X.permute(0, 2, 3, 1)
            tilde_X_golden = tilde_X.permute(0, 2, 3, 1, 4).reshape(
                batch_size, num_subbands, num_frames, num_channels * filter_length
            )
            Q_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * tilde_X_golden)
            R_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * X_golden)

            # Unit under test
            Q_uut, R_uut = wpe_filter.estimate_correlations(input=X, weight=weight, tilde_input=tilde_X)
            # Flatten (channel, tap) dimensions to match the golden layout
            Q_uut_flattened = Q_uut.flatten(start_dim=-2, end_dim=-1).flatten(start_dim=-3, end_dim=-2)
            R_uut_flattened = R_uut.flatten(start_dim=-3, end_dim=-2)

            assert torch.allclose(Q_uut_flattened, Q_golden, atol=atol), f'Example {n}: comparison failed for Q'
            assert torch.allclose(R_uut_flattened, R_golden, atol=atol), f'Example {n}: comparison failed for R'

            # Test 2: filter estimate G = Q^-1 R
            G_golden = torch.linalg.solve(Q_golden, R_golden)

            # Unit under test
            G_uut = wpe_filter.estimate_filter(Q_uut, R_uut)
            # Reshape and permute to match the golden layout
            G_uut_flattened = G_uut.reshape(batch_size, num_channels, num_subbands, -1).permute(0, 2, 3, 1)

            assert torch.allclose(G_uut_flattened, G_golden, atol=atol), f'Example {n}: comparison failed for G'

            # Test 3: apply the filter to obtain the undesired (late reverb) signal
            U_golden = torch.matmul(tilde_X_golden, G_golden)

            # Unit under test
            U_uut = wpe_filter.apply_filter(filter=G_uut, tilde_input=tilde_X)
            U_uut_ref = U_uut.permute(0, 2, 3, 1)

            assert torch.allclose(
                U_uut_ref, U_golden, atol=atol
            ), f'Example {n}: comparison failed for undesired output U'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [3])
    @pytest.mark.parametrize('filter_length', [5])
    @pytest.mark.parametrize('delay', [0, 2])
    def test_mask_based_dereverb_init(self, num_channels: int, filter_length: int, delay: int):
        """Test that dereverb can be initialized and can process audio."""
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21
        num_iterations = 2

        input_size = (batch_size, num_subbands, num_frames, num_channels)

        dereverb = MaskBasedDereverbWPE(
            filter_length=filter_length, prediction_delay=delay, num_iterations=num_iterations
        )

        for n in range(num_examples):
            # Random multi-channel complex input
            x = torch.randn(input_size) + 1j * torch.randn(input_size)
            # Random valid length per example
            x_length = torch.randint(1, num_frames, (batch_size,))
            # Random mask
            mask = torch.rand(input_size)

            # Process the input
            y, y_length = dereverb(input=x, input_length=x_length, mask=mask)

            # Fix: these messages were plain strings, so `{n}` was never interpolated
            assert y.shape == x.shape, f'Output shape not matching, example {n}'
            assert torch.equal(y_length, x_length), f'Length not matching, example {n}'
|
|
|
|
class TestMaskEstimator:
    """Unit tests for mask estimators: flexible-channel NN estimator and GSS."""

    @pytest.mark.unit
    @pytest.mark.parametrize('channel_reduction_position', [0, 1, -1])
    @pytest.mark.parametrize('channel_reduction_type', ['average', 'attention'])
    @pytest.mark.parametrize('channel_block_type', ['transform_average_concatenate', 'transform_attend_concatenate'])
    def test_flex_channels(
        self, channel_reduction_position: int, channel_reduction_type: str, channel_block_type: str
    ):
        """Test initialization of the mask estimator and make sure it can process input tensor."""
        # Model configurations to exercise (same order as the original nested loops)
        model_configs = [
            (num_subbands, num_outputs, num_blocks)
            for num_subbands in [32, 65]
            for num_outputs in [1, 2]
            for num_blocks in [1, 5]
        ]

        # Input configuration
        num_channels_tests = [1, 4]
        batch_size = 4
        num_frames = 50

        for num_subbands, num_outputs, num_blocks in model_configs:
            logging.debug(
                'Instantiate with num_subbands=%d, num_outputs=%d, num_blocks=%d',
                num_subbands,
                num_outputs,
                num_blocks,
            )

            # Instantiate the estimator under test
            estimator = MaskEstimatorFlexChannels(
                num_outputs=num_outputs,
                num_subbands=num_subbands,
                num_blocks=num_blocks,
                channel_reduction_position=channel_reduction_position,
                channel_reduction_type=channel_reduction_type,
                channel_block_type=channel_block_type,
            )

            # The same model must support inputs with different numbers of channels
            for num_channels in num_channels_tests:
                logging.debug('Process num_channels=%d', num_channels)
                input_size = (batch_size, num_channels, num_subbands, num_frames)

                # Random complex-valued spectrogram and valid lengths
                spec = torch.randn(input_size, dtype=torch.cfloat)
                spec_length = torch.randint(1, num_frames, (batch_size,))

                # Run the estimator
                mask, mask_length = estimator(input=spec, input_length=spec_length)

                # One mask per output, channel dimension reduced away
                expected_mask_shape = (batch_size, num_outputs, num_subbands, num_frames)
                assert (
                    mask.shape == expected_mask_shape
                ), f'Output shape mismatch: expected {expected_mask_shape}, got {mask.shape}'

                # Lengths pass through unchanged
                assert torch.all(
                    mask_length == spec_length
                ), f'Output length mismatch: expected {spec_length}, got {mask_length}'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('num_subbands', [32, 65])
    @pytest.mark.parametrize('num_outputs', [2, 3])
    @pytest.mark.parametrize('batch_size', [1, 4])
    def test_gss(self, num_channels: int, num_subbands: int, num_outputs: int, batch_size: int):
        """Test initialization of the GSS mask estimator and make sure it can process an input tensor.
        This tests initialization and the output shape. It does not test correctness of the output.
        """
        num_frames = 50

        # GSS estimator takes no configuration arguments
        estimator = MaskEstimatorGSS()

        logging.debug('Process num_channels=%d', num_channels)
        input_size = (batch_size, num_channels, num_subbands, num_frames)
        logging.debug('Input size: %s', input_size)

        # Random complex mixture and a boolean per-source activity pattern
        mixture_spec = torch.randn(input_size, dtype=torch.cfloat)
        source_activity = torch.randn(batch_size, num_outputs, num_frames) > 0

        # Run the estimator
        mask = estimator(input=mixture_spec, activity=source_activity)

        # One mask per output source
        expected_mask_shape = (batch_size, num_outputs, num_subbands, num_frames)
        assert (
            mask.shape == expected_mask_shape
        ), f'Output shape mismatch: expected {expected_mask_shape}, got {mask.shape}'
|
|
|
|
class TestSSLPretrainMaskingWithPatch:
    """Unit tests for SSL pretraining patch masking."""

    @pytest.mark.unit
    @pytest.mark.parametrize('patch_size', [1, 5, 10])
    @pytest.mark.parametrize('mask_fraction', [0.5, 1.0])
    @pytest.mark.parametrize('training', [True, False])
    def test_masking(self, patch_size: int, mask_fraction: float, training: bool):
        """Test SSL pretrain masking."""
        num_subbands = 32
        num_frames = 5000
        num_channels = 1
        batch_size = 8
        abs_tol = 1e-2

        # Unit under test
        model = SSLPretrainWithMaskedPatch(patch_size=patch_size, mask_fraction=mask_fraction)

        # Select train or eval mode
        (model.train if training else model.eval)()

        # Deterministic random input, zero-padded beyond each example's length
        generator = torch.Generator()
        generator.manual_seed(0)
        input_spec = torch.randn(
            batch_size, num_channels, num_subbands, num_frames, dtype=torch.cfloat, generator=generator
        )
        input_length = torch.randint(num_frames // 2, num_frames, (batch_size,), generator=generator)
        for b in range(batch_size):
            input_spec[b, :, :, input_length[b] :] = 0.0

        # Apply masking
        masked_spec = model(input_spec=input_spec, length=input_length)

        # Shape must be preserved
        assert masked_spec.shape == input_spec.shape

        # Fraction of zeroed bins within the valid length should match the requested fraction
        for b in range(batch_size):
            num_zeroed = torch.sum(masked_spec[b, :, :, : input_length[b]].abs() == 0.0)
            observed_fraction = num_zeroed / (num_channels * num_subbands * input_length[b])

            assert (
                abs(observed_fraction - mask_fraction) < abs_tol
            ), f'Example {b}: est_mask_fraction = {observed_fraction}, mask_fraction = {mask_fraction}'

    @pytest.mark.unit
    def test_unsupported_initialization(self):
        """Test SSL pretrain masking."""
        # Each out-of-range parameter must be rejected at construction time
        for bad_kwargs in ({'patch_size': 0}, {'patch_size': -1}, {'mask_fraction': 1.1}, {'mask_fraction': -0.1}):
            with pytest.raises(ValueError):
                SSLPretrainWithMaskedPatch(**bad_kwargs)
|
|