| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import pytest |
| | import torch |
| |
|
| | from nemo.collections.common.parts.utils import mask_sequence_tensor |
| | from nemo.collections.tts.modules.audio_codec_modules import ( |
| | CodecActivation, |
| | Conv1dNorm, |
| | ConvTranspose1dNorm, |
| | FiniteScalarQuantizer, |
| | GroupFiniteScalarQuantizer, |
| | HiFiGANDecoder, |
| | MelSpectrogramProcessor, |
| | MultiBandMelEncoder, |
| | ResidualBlock, |
| | ResNetEncoder, |
| | get_down_sample_padding, |
| | ) |
| | from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer |
| |
|
| |
|
class TestAudioCodecModules:
    """Masking and shape tests for the audio codec convolution, encoder and decoder modules.

    Each test runs a padded batch of two sequences through a module and asserts that the
    output is non-zero within each item's valid length and exactly zero in the padded tail.
    """

    @classmethod
    def setup_class(cls):
        """Define the tensor dimensions shared by all tests in this class.

        pytest calls this xunit-style hook with the class object, so it is a classmethod and
        the values are stored as class attributes (readable through ``self`` in the tests).
        """
        cls.in_channels = 8
        cls.out_channels = 16
        cls.filters = 32
        cls.batch_size = 2
        cls.len1 = 4
        cls.len2 = 8
        cls.max_len = 10
        cls.kernel_size = 3

    @staticmethod
    def _assert_masked(out, valid_lengths):
        """Assert every batch item is non-zero up to its valid length and zero afterwards.

        Args:
            out: tensor whose first dimension is the batch and whose last dimension is time.
            valid_lengths: one integer valid length per batch item, in batch order.
        """
        for idx, valid_len in enumerate(valid_lengths):
            # The ellipsis lets the same check serve both [B, C, T] and [B, T] outputs.
            assert torch.all(out[idx, ..., :valid_len] != 0.0), f'item {idx}: zero inside the valid region'
            assert torch.all(out[idx, ..., valid_len:] == 0.0), f'item {idx}: non-zero inside the padding'

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_conv1d(self):
        """Conv1dNorm keeps the time dimension and zeroes everything past each input length."""
        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)

        conv = Conv1dNorm(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size)
        out = conv(inputs=inputs, input_len=lengths)

        assert out.shape == (self.batch_size, self.out_channels, self.max_len)
        self._assert_masked(out, [self.len1, self.len2])

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_conv1d_downsample(self):
        """A strided Conv1dNorm shrinks the time dimension by the stride and masks the result."""
        stride = 2
        out_len = self.max_len // stride
        out_len_1 = self.len1 // stride
        out_len_2 = self.len2 // stride
        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        # NOTE(review): the post-stride lengths are passed as input_len; presumably the module
        # uses input_len to mask its output -- confirm against Conv1dNorm.
        lengths = torch.tensor([out_len_1, out_len_2], dtype=torch.int32)

        padding = get_down_sample_padding(kernel_size=self.kernel_size, stride=stride)
        conv = Conv1dNorm(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=stride,
            padding=padding,
        )
        out = conv(inputs=inputs, input_len=lengths)

        assert out.shape == (self.batch_size, self.out_channels, out_len)
        self._assert_masked(out, [out_len_1, out_len_2])

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_conv1d_transpose_upsample(self):
        """A strided ConvTranspose1dNorm grows the time dimension by the stride and masks the result."""
        stride = 2
        out_len = self.max_len * stride
        out_len_1 = self.len1 * stride
        out_len_2 = self.len2 * stride
        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        # NOTE(review): as in the down-sampling test, input_len carries the output lengths.
        lengths = torch.tensor([out_len_1, out_len_2], dtype=torch.int32)

        conv = ConvTranspose1dNorm(
            in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=stride
        )
        out = conv(inputs=inputs, input_len=lengths)

        assert out.shape == (self.batch_size, self.out_channels, out_len)
        self._assert_masked(out, [out_len_1, out_len_2])

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_residual_block(self):
        """ResidualBlock preserves the input shape and the zero padding of a pre-masked input."""
        lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)
        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        inputs = mask_sequence_tensor(tensor=inputs, lengths=lengths)

        res_block = ResidualBlock(channels=self.in_channels, filters=self.filters)
        out = res_block(inputs=inputs, input_len=lengths)

        assert out.shape == (self.batch_size, self.in_channels, self.max_len)
        self._assert_masked(out, [self.len1, self.len2])

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_hifigan_decoder(self):
        """HiFiGANDecoder up-samples lengths by the product of its rates and masks the waveform."""
        up_sample_rates = [4, 4, 2, 2]
        # Derive the total factor from the rates so the two can never drift apart.
        up_sample_total = 1
        for rate in up_sample_rates:
            up_sample_total *= rate

        lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)
        out_len_1 = self.len1 * up_sample_total
        out_len_2 = self.len2 * up_sample_total
        out_len_max = self.max_len * up_sample_total

        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        inputs = mask_sequence_tensor(tensor=inputs, lengths=lengths)

        decoder = HiFiGANDecoder(
            input_dim=self.in_channels, base_channels=self.filters, up_sample_rates=up_sample_rates
        )
        out, out_len = decoder(inputs=inputs, input_len=lengths)

        assert out_len[0] == out_len_1
        assert out_len[1] == out_len_2
        # The decoder output has no channel dimension: one waveform per batch item.
        assert out.shape == (self.batch_size, out_len_max)
        self._assert_masked(out, [out_len_1, out_len_2])

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_resnet_encoder(self):
        """ResNetEncoder maps to out_channels, keeps the time dimension and masks the padding."""
        lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32)
        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        inputs = mask_sequence_tensor(tensor=inputs, lengths=lengths)

        res_net = ResNetEncoder(in_channels=self.in_channels, out_channels=self.out_channels)
        out = res_net(inputs=inputs, input_len=lengths)

        assert out.shape == (self.batch_size, self.out_channels, self.max_len)
        self._assert_masked(out, [self.len1, self.len2])

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_multiband_mel_encoder(self):
        """MultiBandMelEncoder emits out_channels per mel band and masks frames past each length."""
        mel_dim = 10
        win_length = 16
        hop_length = 10
        mel_bands = [(0, 4), (4, 7), (7, 10)]
        max_len = 100
        len1 = 40
        len2 = 80
        out_dim = len(mel_bands) * self.out_channels
        lengths = torch.tensor([len1, len2], dtype=torch.int32)
        # Expected frame counts for the shorter items are one less than len // hop_length,
        # while the padded output keeps max_len // hop_length frames.
        out_len_1 = len1 // hop_length - 1
        out_len_2 = len2 // hop_length - 1
        out_len_max = max_len // hop_length

        audio = torch.rand([self.batch_size, max_len])
        audio = mask_sequence_tensor(tensor=audio, lengths=lengths)

        mel_processor = MelSpectrogramProcessor(
            mel_dim=mel_dim, sample_rate=100, win_length=win_length, hop_length=hop_length
        )
        encoder = MultiBandMelEncoder(
            mel_bands=mel_bands,
            mel_processor=mel_processor,
            out_channels=self.out_channels,
            hidden_channels=self.filters,
            filters=self.filters,
        )
        out, out_len = encoder(audio=audio, audio_len=lengths)

        assert out_len[0] == out_len_1
        assert out_len[1] == out_len_2
        assert out.shape == (self.batch_size, out_dim, out_len_max)
        self._assert_masked(out, [out_len_1, out_len_2])
| |
|
| |
|
class TestResidualVectorQuantizer:
    """Consistency tests for the residual vector quantizer and its grouped variant."""

    @classmethod
    def setup_class(cls):
        """Setup common members.

        pytest calls this xunit-style hook with the class object, so it is a classmethod and
        the values are stored as class attributes.
        """
        cls.batch_size = 2
        cls.max_len = 20
        cls.codebook_size = 256
        cls.codebook_dim = 64
        cls.num_examples = 10

    @pytest.mark.unit
    @pytest.mark.parametrize('num_codebooks', [1, 4])
    def test_rvq_eval(self, num_codebooks: int):
        """Simple test to confirm that the RVQ module can be instantiated and run,
        and that forward produces the same result as encode-decode.
        """
        # Evaluation mode: the test expects a zero commit loss from forward in this mode.
        rvq = ResidualVectorQuantizer(num_codebooks=num_codebooks, codebook_dim=self.codebook_dim)
        rvq.eval()

        for i in range(self.num_examples):
            inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
            input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # Full forward pass
            dequantized_fw, indices_fw, commit_loss = rvq(inputs=inputs, input_len=input_len)

            assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0'

            # Separate encode followed by decode
            indices_enc = rvq.encode(inputs=inputs, input_len=input_len)
            dequantized_dec = rvq.decode(indices=indices_enc, input_len=input_len)

            # Both paths must agree
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
            torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

    @pytest.mark.unit
    @pytest.mark.parametrize('num_groups', [1, 2, 4])
    @pytest.mark.parametrize('num_codebooks', [1, 4])
    def test_group_rvq_eval(self, num_groups: int, num_codebooks: int):
        """Simple test to confirm that the group RVQ module can be instantiated and run,
        and that forward produces the same result as encode-decode.
        """
        if num_groups > num_codebooks:
            # More groups than codebooks is an invalid setup: construction must fail.
            with pytest.raises(ValueError):
                _ = GroupResidualVectorQuantizer(
                    num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
                )
            return

        # Valid setup: evaluation mode, as in the non-grouped test.
        grvq = GroupResidualVectorQuantizer(
            num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
        )
        grvq.eval()

        for i in range(self.num_examples):
            inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
            input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # Full forward pass
            dequantized_fw, indices_fw, commit_loss = grvq(inputs=inputs, input_len=input_len)

            assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0'

            # Separate encode followed by decode
            indices_enc = grvq.encode(inputs=inputs, input_len=input_len)
            dequantized_dec = grvq.decode(indices=indices_enc, input_len=input_len)

            # Both paths must agree
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
            torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

            # Each group's own RVQ, run on its channel chunk, must reproduce the matching
            # chunk of the grouped output (channels on dim 1, indices on dim 0).
            inputs_grouped = inputs.chunk(num_groups, dim=1)
            dequantized_fw_grouped = dequantized_fw.chunk(num_groups, dim=1)
            indices_fw_grouped = indices_fw.chunk(num_groups, dim=0)

            for g in range(num_groups):
                dequantized, indices, _ = grvq.rvqs[g](inputs=inputs_grouped[g], input_len=input_len)
                torch.testing.assert_close(
                    dequantized, dequantized_fw_grouped[g], msg=f'example {i}: dequantized mismatch for group {g}'
                )
                torch.testing.assert_close(
                    indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}'
                )
| |
|
| |
|
class TestCodecActivation:
    """Smoke tests for the codec activation wrapper."""

    @classmethod
    def setup_class(cls):
        """Define the input dimensions shared by the tests.

        pytest calls this xunit-style hook with the class object, so it is a classmethod and
        the values are stored as class attributes.
        """
        cls.batch_size = 2
        cls.in_channels = 4
        cls.max_len = 4

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_snake(self):
        """
        Test for snake activation function execution.

        Only checks that the activation runs and returns a tensor with the input's shape.
        """
        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        snake = CodecActivation('snake', channels=self.in_channels)
        out = snake(x=inputs)
        assert out.shape == (self.batch_size, self.in_channels, self.max_len)
| |
|
| |
|
class TestFiniteScalarQuantizer:
    """Consistency and golden-value tests for the finite scalar quantizer and its grouped variant."""

    @classmethod
    def setup_class(cls):
        """Setup common members.

        pytest calls this xunit-style hook with the class object, so it is a classmethod and
        the values are stored as class attributes.
        """
        cls.batch_size = 2
        cls.max_len = 20
        cls.num_examples = 10

    @pytest.mark.unit
    @pytest.mark.parametrize('num_levels', [[2, 3], [8, 5, 5]])
    def test_fsq_eval(self, num_levels: list):
        """Simple test to confirm that the FSQ module can be instantiated and run,
        and that forward produces the same result as encode-decode.
        """
        fsq = FiniteScalarQuantizer(num_levels=num_levels)

        for i in range(self.num_examples):
            inputs = torch.randn([self.batch_size, fsq.codebook_dim, self.max_len])
            input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # Full forward pass
            dequantized_fw, indices_fw = fsq(inputs=inputs, input_len=input_len)

            # Dequantized values are expected to stay within [-1, 1]
            assert dequantized_fw.max() <= 1.0, f'example {i}: dequantized_fw.max() is {dequantized_fw.max()}'
            assert dequantized_fw.min() >= -1.0, f'example {i}: dequantized_fw.min() is {dequantized_fw.min()}'

            # Separate encode followed by decode
            indices_enc = fsq.encode(inputs=inputs, input_len=input_len)
            dequantized_dec = fsq.decode(indices=indices_enc, input_len=input_len)

            # Both paths must agree
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
            torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

    @pytest.mark.unit
    def test_fsq_output(self):
        """Simple test to make sure the output of FSQ is correct for a single setup.

        To re-generate test vectors:
        ```
        num_examples, max_len = 5, 8
        inputs = torch.randn([num_examples, fsq.codebook_dim, max_len])
        input_len = torch.tensor([max_len] * num_examples, dtype=torch.int32)
        dequantized, indices = fsq(inputs=inputs, input_len=input_len)
        ```
        """
        num_levels = [3, 4]
        fsq = FiniteScalarQuantizer(num_levels=num_levels)

        # Fixed inputs: 5 examples, 2 channels (one per entry of num_levels), 8 time steps
        inputs = torch.tensor(
            [
                [
                    [0.1483, -0.3855, -0.3715, -0.5913, -0.2212, -0.4226, -0.4864, -1.6069],
                    [-0.5519, -0.5307, -0.5995, -1.9675, -0.4439, 0.3938, -0.5636, -0.3655],
                ],
                [
                    [0.5184, 1.4028, 0.1553, -0.2324, 1.0363, -0.4981, -0.1203, -1.0335],
                    [-0.1567, -0.2274, 0.0424, -0.0819, -0.2122, -2.1851, -1.5035, -1.2237],
                ],
                [
                    [0.9497, 0.8510, -1.2021, 0.3299, -0.2388, 0.8445, 2.2129, -2.3383],
                    [1.5331, 0.0399, -0.7676, -0.4715, -0.5713, 0.8761, -0.9755, -0.7479],
                ],
                [
                    [1.7243, -1.2146, -0.1969, 1.9261, 0.1109, 0.4028, 0.1240, -0.0994],
                    [-0.3304, 2.1239, 0.1004, -1.4060, 1.1463, -0.0557, -0.5856, -1.2441],
                ],
                [
                    [2.3743, -0.1421, -0.4548, 0.6320, -0.2640, -0.3967, -2.5694, 0.0493],
                    [0.3409, 0.2366, -0.0309, -0.7652, 0.3484, -0.8419, 0.9079, -0.9929],
                ],
            ]
        )

        input_len = torch.tensor([8, 8, 8, 8, 8], dtype=torch.int32)

        # Expected dequantized values for the inputs above
        dequantized_expected = torch.tensor(
            [
                [
                    [0.0000, 0.0000, 0.0000, -1.0000, 0.0000, 0.0000, 0.0000, -1.0000],
                    [-0.5000, -0.5000, -0.5000, -1.0000, -0.5000, 0.0000, -0.5000, -0.5000],
                ],
                [
                    [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.0000, -1.0000, -1.0000],
                ],
                [
                    [1.0000, 1.0000, -1.0000, 0.0000, 0.0000, 1.0000, 1.0000, -1.0000],
                    [0.5000, 0.0000, -0.5000, -0.5000, -0.5000, 0.5000, -0.5000, -0.5000],
                ],
                [
                    [1.0000, -1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.5000, 0.0000, -1.0000, 0.5000, 0.0000, -0.5000, -1.0000],
                ],
                [
                    [1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, -0.5000, 0.0000, -0.5000, 0.5000, -0.5000],
                ],
            ]
        )

        # Expected indices carry a leading dimension of size one ahead of (example, time)
        indices_expected = torch.tensor(
            [
                [
                    [4, 4, 4, 0, 4, 7, 4, 3],
                    [7, 8, 7, 7, 8, 1, 1, 0],
                    [11, 8, 3, 4, 4, 11, 5, 3],
                    [8, 9, 7, 2, 10, 7, 4, 1],
                    [8, 7, 7, 5, 7, 4, 9, 4],
                ]
            ],
            dtype=torch.int32,
        )

        # Run the quantizer and compare against the golden values
        dequantized, indices = fsq(inputs=inputs, input_len=input_len)
        torch.testing.assert_close(dequantized, dequantized_expected, msg='dequantized mismatch')
        torch.testing.assert_close(indices, indices_expected, msg='indices mismatch')

    @pytest.mark.unit
    @pytest.mark.parametrize('num_groups', [1, 2, 4])
    @pytest.mark.parametrize('num_levels_per_group', [[2, 3], [8, 5, 5]])
    def test_group_fsq_eval(self, num_groups: int, num_levels_per_group: list):
        """Simple test to confirm that the group FSQ module can be instantiated and run,
        and that forward produces the same result as encode-decode.
        """
        # Every parametrized combination is constructed directly; no error case is exercised here.
        gfsq = GroupFiniteScalarQuantizer(num_groups=num_groups, num_levels_per_group=num_levels_per_group)

        for i in range(self.num_examples):
            inputs = torch.randn([self.batch_size, gfsq.codebook_dim, self.max_len])
            input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # Full forward pass
            dequantized_fw, indices_fw = gfsq(inputs=inputs, input_len=input_len)

            # Separate encode followed by decode
            indices_enc = gfsq.encode(inputs=inputs, input_len=input_len)
            dequantized_dec = gfsq.decode(indices=indices_enc, input_len=input_len)

            # Both paths must agree
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
            torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

            # Each group's own FSQ, run on its channel chunk, must reproduce the matching
            # chunk of the grouped output (channels on dim 1, indices on dim 0).
            inputs_grouped = inputs.chunk(num_groups, dim=1)
            dequantized_fw_grouped = dequantized_fw.chunk(num_groups, dim=1)
            indices_fw_grouped = indices_fw.chunk(num_groups, dim=0)

            for g in range(num_groups):
                dequantized, indices = gfsq.fsqs[g](inputs=inputs_grouped[g], input_len=input_len)
                torch.testing.assert_close(
                    dequantized, dequantized_fw_grouped[g], msg=f'example {i}: dequantized mismatch for group {g}'
                )
                torch.testing.assert_close(
                    indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}'
                )
| |
|