| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import itertools |
| | import pytest |
| | import torch |
| |
|
| | from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( |
| | find_best_permutation, |
| | find_first_nonzero, |
| | get_ats_targets, |
| | get_hidden_length_from_sample_length, |
| | get_pil_targets, |
| | reconstruct_labels, |
| | ) |
| |
|
| |
|
def reconstruct_labels_forloop(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor:
    """
    Reference implementation of `reconstruct_labels` using explicit Python loops.

    Built for testing purposes: the vectorized `reconstruct_labels` is compared
    against this version, so this one must actually be a plain for-loop. (The
    previous body used the same `torch.gather` call as the vectorized version,
    which made the comparison test vacuous.)

    Args:
        labels (torch.Tensor): Speaker label tensor of shape
            (batch_size, num_frames, num_speakers).
        batch_perm_inds (torch.Tensor): Per-batch speaker permutation indices of
            shape (batch_size, num_speakers).

    Returns:
        torch.Tensor: Tensor of the same shape as `labels`, where
            out[b, t, s] = labels[b, t, batch_perm_inds[b, s]].
    """
    batch_size, num_frames, num_speakers = labels.shape
    reconstructed_labels = torch.empty_like(labels)
    # Element-by-element gather along the speaker dimension, written out
    # explicitly so it is an independent check of the vectorized version.
    for b in range(batch_size):
        for t in range(num_frames):
            for s in range(num_speakers):
                reconstructed_labels[b, t, s] = labels[b, t, batch_perm_inds[b, s]]
    return reconstructed_labels
| |
|
| |
|
class TestSortingUtils:
    """Unit tests for sorting/permutation helpers in asr_multispeaker_utils."""

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "mat, max_cap_val, thres, expected",
        [
            # Basic case: first value above threshold per row.
            (torch.tensor([[0.1, 0.6, 0.0], [0.0, 0.0, 0.9]]), -1, 0.5, torch.tensor([1, 2])),
            # No value above threshold: fall back to max_cap_val.
            (torch.tensor([[0.1, 0.2], [0.3, 0.4]]), -1, 0.5, torch.tensor([-1, -1])),
            # All-zero rows.
            (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), -1, 0.5, torch.tensor([-1, -1])),
            # Mixed rows, including a row whose first element already passes.
            (torch.tensor([[0.1, 0.7, 0.3], [0.0, 0.0, 0.9], [0.5, 0.6, 0.7]]), -1, 0.5, torch.tensor([1, 2, 0])),
            # Single row.
            (torch.tensor([[0.0, 0.0, 0.6]]), -1, 0.5, torch.tensor([2])),
            # Single column.
            (torch.tensor([[0.1], [0.6], [0.0]]), -1, 0.5, torch.tensor([-1, 0, -1])),
            # Value just above threshold.
            (torch.tensor([[0.501]]), -1, 0.5, torch.tensor([0], dtype=torch.long)),
            # All below threshold.
            (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), -1, 0.5, torch.tensor([-1, -1])),
            # Every value above threshold: index 0 for each row.
            (torch.tensor([[0.6, 0.7], [0.8, 0.9]]), -1, 0.5, torch.tensor([0, 0])),
            # Custom cap value used when nothing passes.
            (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), 99, 0.5, torch.tensor([99, 99])),
            # Long row with the only nonzero at the end.
            (torch.cat([torch.zeros(1, 100), torch.ones(1, 1)], dim=1), -1, 0.5, torch.tensor([100])),
            # Long row with a single passing value in the middle.
            (
                torch.cat([torch.zeros(1, 499), torch.tensor([[0.6]]), torch.zeros(1, 500)], dim=1),
                -1,
                0.5,
                torch.tensor([499]),
            ),
        ],
    )
    def test_find_first_nonzero(self, mat, max_cap_val, thres, expected):
        """find_first_nonzero returns the first index per row above `thres`, else `max_cap_val`."""
        result = find_first_nonzero(mat, max_cap_val, thres)
        assert torch.equal(result, expected), f"Expected {expected} but got {result}"

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "match_score, speaker_permutations, expected",
        [
            # Single batch item: permutation with the highest score wins.
            (
                torch.tensor([[0.1, 0.9, 0.2]]),
                torch.tensor([[0, 1], [1, 0], [0, 1]]),
                torch.tensor([[1, 0]]),
            ),
            # Two batch items, best permutation differs per item.
            (
                torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.6, 0.4]]),
                torch.tensor([[0, 1], [1, 0], [0, 1]]),
                torch.tensor([[0, 1], [1, 0]]),
            ),
            # Three-speaker permutations.
            (
                torch.tensor(
                    [[0.1, 0.4, 0.9, 0.5], [0.6, 0.3, 0.7, 0.2]]
                ),
                torch.tensor(
                    [[0, 1, 2], [1, 0, 2], [2, 1, 0], [1, 2, 0]]
                ),
                torch.tensor([[2, 1, 0], [2, 1, 0]]),
            ),
            # Ties: the first maximal permutation is chosen.
            (
                torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]),
                torch.tensor([[0, 1], [1, 0], [0, 1]]),
                torch.tensor([[0, 1], [0, 1]]),
            ),
            # Degenerate single-speaker permutations.
            (
                torch.tensor([[0.8, 0.2]]),
                torch.tensor([[0], [0]]),
                torch.tensor([[0]]),
            ),
            # Three batch items, two permutations.
            (
                torch.tensor([[0.3, 0.6], [0.4, 0.1], [0.2, 0.7]]),
                torch.tensor([[0, 1], [1, 0]]),
                torch.tensor([[1, 0], [0, 1], [1, 0]]),
            ),
        ],
    )
    def test_find_best_permutation(self, match_score, speaker_permutations, expected):
        """find_best_permutation picks, per batch item, the permutation with the highest match score."""
        result = find_best_permutation(match_score, speaker_permutations)
        assert torch.equal(result, expected), f"Expected {expected} but got {result}"

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "batch_size, num_frames, num_speakers",
        [
            (2, 4, 3),
            (3, 5, 2),
            (1, 6, 4),
            (5, 3, 5),
        ],
    )
    def test_reconstruct_labels_with_forloop_ver(self, batch_size, num_frames, num_speakers):
        """Vectorized reconstruct_labels must match the for-loop reference on random inputs."""
        labels = torch.rand(batch_size, num_frames, num_speakers)
        batch_perm_inds = torch.stack([torch.randperm(num_speakers) for _ in range(batch_size)])

        result_matrix = reconstruct_labels(labels, batch_perm_inds)
        result_forloop = reconstruct_labels_forloop(labels, batch_perm_inds)

        assert torch.allclose(result_matrix, result_forloop), "The results are not equal!"

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "labels, batch_perm_inds, expected_output",
        [
            # Two batch items, full 3-speaker permutations.
            (
                torch.tensor(
                    [
                        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
                        [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]],
                    ]
                ),
                torch.tensor([[2, 0, 1], [1, 2, 0]]),
                torch.tensor(
                    [
                        [[0.3, 0.1, 0.2], [0.6, 0.4, 0.5], [0.9, 0.7, 0.8]],
                        [[0.8, 0.7, 0.9], [0.5, 0.4, 0.6], [0.2, 0.1, 0.3]],
                    ]
                ),
            ),
            # Single batch item, 4 speakers, cyclic shift.
            (
                torch.tensor(
                    [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]]]
                ),
                torch.tensor([[3, 0, 1, 2]]),
                torch.tensor(
                    [[[0.4, 0.1, 0.2, 0.3], [0.8, 0.5, 0.6, 0.7], [1.2, 0.9, 1.0, 1.1], [1.6, 1.3, 1.4, 1.5]]]
                ),
            ),
            # Four batch items, 2 speakers, mixed swap/identity permutations.
            (
                torch.tensor(
                    [
                        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                        [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
                        [[1.3, 1.4], [1.5, 1.6], [1.7, 1.8]],
                        [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]],
                    ]
                ),
                torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]]),
                torch.tensor(
                    [
                        [[0.2, 0.1], [0.4, 0.3], [0.6, 0.5]],
                        [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
                        [[1.4, 1.3], [1.6, 1.5], [1.8, 1.7]],
                        [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]],
                    ]
                ),
            ),
        ],
    )
    def test_reconstruct_labels(self, labels, batch_perm_inds, expected_output):
        """reconstruct_labels permutes the speaker dimension per batch item as specified."""
        result = reconstruct_labels(labels, batch_perm_inds)
        assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}"
|
| |
|
class TestTargetGenerators:
    """Unit tests for ATS/PIL target-generation helpers in asr_multispeaker_utils."""

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "labels, preds, num_speakers, expected_output",
        [
            # Two batch items, 3 speakers, soft labels/preds.
            (
                torch.tensor(
                    [
                        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]],
                        [[0.0, 0.0, 0.9], [0.0, 0.9, 0.1], [0.9, 0.1, 0.0]],
                    ]
                ),
                torch.tensor(
                    [
                        [[0.8, 0.2, 0.0], [0.2, 0.7, 0.0], [0.0, 0.1, 0.9]],
                        [[0.0, 0.0, 0.8], [0.0, 0.8, 0.2], [0.9, 0.1, 0.0]],
                    ]
                ),
                3,
                torch.tensor(
                    [
                        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]],
                        [[0.9, 0.0, 0.0], [0.1, 0.9, 0.0], [0.0, 0.1, 0.9]],
                    ]
                ),
            ),
            # Single batch item, overlapping soft scores.
            (
                torch.tensor([[[0.9, 0.8, 0.7], [0.2, 0.8, 0.7], [0.2, 0.3, 0.9]]]),
                torch.tensor([[[0.6, 0.7, 0.2], [0.9, 0.4, 0.0], [0.1, 0.7, 0.1]]]),
                3,
                torch.tensor([[[0.8, 0.7, 0.9], [0.8, 0.7, 0.2], [0.3, 0.9, 0.2]]]),
            ),
            # Binary labels with 4 speakers, including an all-silence frame.
            (
                torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]),
                torch.tensor(
                    [[[0.6, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]]
                ),
                4,
                torch.tensor([[[1, 1, 0, 0], [1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0]]]),
            ),
        ],
    )
    def test_get_ats_targets(self, labels, preds, num_speakers, expected_output):
        """get_ats_targets reorders speaker labels (arrival-time sorting) to best match predictions."""
        speaker_inds = list(range(num_speakers))
        # All num_speakers! permutations of speaker indices.
        speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds)))

        result = get_ats_targets(labels, preds, speaker_permutations)
        assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}"

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "labels, preds, num_speakers, expected_output",
        [
            # Binary labels: second batch item needs a speaker swap.
            (
                torch.tensor(
                    [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]
                ),
                torch.tensor(
                    [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]
                ),
                2,
                torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]),
            ),
            # Soft labels already aligned with predictions: identity permutation.
            (
                torch.tensor([[[0.8, 0.2], [0.3, 0.7]]]),
                torch.tensor([[[0.9, 0.1], [0.2, 0.8]]]),
                2,
                torch.tensor(
                    [[[0.8, 0.2], [0.3, 0.7]]]
                ),
            ),
            # Four speakers: best permutation rearranges label columns.
            (
                torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]),
                torch.tensor(
                    [[[0.61, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]]
                ),
                4,
                torch.tensor([[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]),
            ),
        ],
    )
    def test_get_pil_targets(self, labels, preds, num_speakers, expected_output):
        """get_pil_targets applies permutation-invariant matching of labels to predictions."""
        speaker_inds = list(range(num_speakers))
        # All num_speakers! permutations of speaker indices.
        speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds)))

        result = get_pil_targets(labels, preds, speaker_permutations)
        assert torch.equal(result, expected_output), f"Expected {expected_output} but got {result}"
| |
|
| |
|
class TestGetHiddenLengthFromSampleLength:
    """Unit tests for get_hidden_length_from_sample_length (audio samples -> hidden frames)."""

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length",
        [
            (160, 160, 8, 1),
            (1280, 160, 8, 1),
            (0, 160, 8, 0),
            (159, 160, 8, 1),
            (129, 100, 5, 1),
            (300, 150, 3, 1),
        ],
    )
    def test_various_cases(
        self, num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length
    ):
        """Hidden length is computed correctly for assorted sample counts and frame ratios."""
        result = get_hidden_length_from_sample_length(
            num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame
        )
        assert result == expected_hidden_length

    @pytest.mark.unit
    def test_default_parameters(self):
        """Default frame parameters (160 samples/mel frame, 8 mel frames/ASR frame) are applied."""
        assert get_hidden_length_from_sample_length(160) == 1
        assert get_hidden_length_from_sample_length(1280) == 1
        assert get_hidden_length_from_sample_length(0) == 0
        assert get_hidden_length_from_sample_length(159) == 1

    @pytest.mark.unit
    def test_edge_cases(self):
        """Sample counts straddling frame boundaries still map to the expected hidden length."""
        assert get_hidden_length_from_sample_length(159, 160, 8) == 1
        assert get_hidden_length_from_sample_length(160, 160, 8) == 1
        assert get_hidden_length_from_sample_length(161, 160, 8) == 1
        assert get_hidden_length_from_sample_length(1279, 160, 8) == 1

    @pytest.mark.unit
    def test_real_life_examples(self):
        """Realistic utterance lengths (e.g. ~10 s at 16 kHz) yield the expected frame counts."""
        assert get_hidden_length_from_sample_length(160000) == 125
        assert get_hidden_length_from_sample_length(159999) == 125
        assert get_hidden_length_from_sample_length(158720) == 124
        assert get_hidden_length_from_sample_length(158719) == 124

        assert get_hidden_length_from_sample_length(158880) == 125
        assert get_hidden_length_from_sample_length(158879) == 125
        assert get_hidden_length_from_sample_length(1600) == 2
        assert get_hidden_length_from_sample_length(1599) == 2
| |
|