| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
from typing import Optional

import pytest
import torch

from nemo.collections.audio.modules.projections import MixtureConsistencyProjection
|
|
|
|
class TestMixtureConsistencyProjection:
    """Unit tests for ``MixtureConsistencyProjection``."""

    @pytest.mark.unit
    @pytest.mark.parametrize('weighting', [None, 'power'])
    @pytest.mark.parametrize('num_sources', [1, 3])
    def test_mixture_consistency(self, weighting: Optional[str], num_sources: int):
        """Compare the module output against a reference projection.

        The reference adds ``weight * (mixture - sum(estimate))`` to every source.
        ``weight`` is ``1 / num_sources`` for ``weighting=None``. For
        ``weighting='power'`` it is each source's ``|estimate|^2``, normalized
        over sources (with ``uut.eps`` added to the denominator).
        """
        batch_size = 4
        num_subbands = 33
        num_samples = 100
        num_examples = 8
        atol = 1e-5

        # Seeded generator so the random examples are reproducible across runs.
        rng = torch.Generator()
        rng.manual_seed(42)

        uut = MixtureConsistencyProjection(weighting=weighting)

        for _ in range(num_examples):
            # Complex mixture with a single entry along dim 1 ...
            mixture = torch.randn(batch_size, 1, num_subbands, num_samples, generator=rng, dtype=torch.cfloat)
            # ... and a complex estimate with `num_sources` entries along dim 1.
            estimate = torch.randn(
                batch_size, num_sources, num_subbands, num_samples, generator=rng, dtype=torch.cfloat
            )

            uut_projected = uut(mixture=mixture, estimate=estimate)

            # Reference projection: spread the residual between the mixture and
            # the sum of the estimates across the sources.
            estimated_mixture = torch.sum(estimate, dim=1, keepdim=True)
            if weighting is None:
                weight = 1 / num_sources
            elif weighting == 'power':
                weight = estimate.abs().pow(2)
                weight = weight / (weight.sum(dim=1, keepdim=True) + uut.eps)
            else:
                # Guard in case the parametrization is extended without updating the reference.
                raise ValueError(f'Weighting {weighting} not implemented')
            correction = weight * (mixture - estimated_mixture)
            ref_projected = estimate + correction

            assert torch.allclose(uut_projected, ref_projected, atol=atol)

            if weighting is None:
                # With uniform weights the projected sources must sum back to the
                # mixture (up to floating-point error). This is the defining property
                # of a mixture-consistent projection. The 'power' case is skipped
                # because the eps in its normalization makes the sum only approximate.
                projected_mixture = torch.sum(uut_projected, dim=1, keepdim=True)
                assert torch.allclose(projected_mixture, mixture, atol=atol)

    @pytest.mark.unit
    def test_unsupported_weighting(self):
        """Unknown weighting modes are rejected both at construction and at call time."""
        # The constructor rejects an unknown weighting.
        with pytest.raises(NotImplementedError):
            MixtureConsistencyProjection(weighting='not-implemented')

        # Bypass constructor validation by mutating the attribute, then check
        # that the forward call also rejects the unknown weighting.
        uut = MixtureConsistencyProjection(weighting=None)
        uut.weighting = 'not-implemented'
        with pytest.raises(NotImplementedError):
            uut(
                mixture=torch.randn(1, 1, 1, 1, dtype=torch.cfloat),
                estimate=torch.randn(1, 1, 1, 1, dtype=torch.cfloat),
            )

    @pytest.mark.unit
    def test_unsupported_inputs(self):
        """Inputs with an unexpected shape or dtype raise errors."""
        uut = MixtureConsistencyProjection(weighting=None)

        # A mixture with size 2 along dim 1 raises ValueError
        # (presumably the mixture must be single-channel — confirm against the module).
        with pytest.raises(ValueError):
            uut(
                mixture=torch.randn(1, 2, 1, 1, dtype=torch.cfloat),
                estimate=torch.randn(1, 2, 1, 1, dtype=torch.cfloat),
            )

        # Real-valued 3-D inputs raise TypeError
        # (presumably from the module's input type checking — confirm).
        with pytest.raises(TypeError):
            uut(mixture=torch.randn(1, 2, 1), estimate=torch.randn(1, 2, 1))

        # A valid mixture paired with a real-valued 3-D estimate also raises TypeError.
        with pytest.raises(TypeError):
            uut(mixture=torch.randn(1, 1, 1, 1, dtype=torch.cfloat), estimate=torch.randn(1, 2, 1))
|
|