# NeMo_Canary / tests /collections /tts /modules /test_audio_codec_modules.py
# Respair's picture
# Upload folder using huggingface_hub
# b386992 verified
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from nemo.collections.common.parts.utils import mask_sequence_tensor
from nemo.collections.tts.modules.audio_codec_modules import (
CodecActivation,
Conv1dNorm,
ConvTranspose1dNorm,
FiniteScalarQuantizer,
GroupFiniteScalarQuantizer,
HiFiGANDecoder,
MelSpectrogramProcessor,
MultiBandMelEncoder,
ResidualBlock,
ResNetEncoder,
get_down_sample_padding,
)
from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer
class TestAudioCodecModules:
    """Shape and padding checks for the convolutional audio codec building blocks."""

    def setup_class(self):
        """Define the tensor dimensions shared by all tests in this class."""
        # Two sequences with different valid lengths, padded to max_len.
        self.batch_size = 2
        self.len1 = 4
        self.len2 = 8
        self.max_len = 10
        # Channel counts and kernel width used to build the modules under test.
        self.in_channels = 8
        self.out_channels = 16
        self.filters = 32
        self.kernel_size = 3

    @staticmethod
    def _assert_masked(out, valid_lens):
        """Assert each batch item is non-zero within its valid length and zero after it.

        Args:
            out: tensor with batch as the first axis and time as the last axis.
            valid_lens: number of valid time steps for each batch item, in batch order.
        """
        for item, valid in enumerate(valid_lens):
            assert torch.all(out[item, ..., :valid] != 0.0)
            assert torch.all(out[item, ..., valid:] == 0.0)

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_conv1d(self):
        """Conv1dNorm keeps the time axis length and outputs zeros past each sequence length."""
        batch = torch.rand([self.batch_size, self.in_channels, self.max_len])
        seq_lens = torch.tensor([self.len1, self.len2], dtype=torch.int32)

        layer = Conv1dNorm(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size)
        result = layer(inputs=batch, input_len=seq_lens)

        assert result.shape == (self.batch_size, self.out_channels, self.max_len)
        self._assert_masked(result, (self.len1, self.len2))

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_conv1d_downsample(self):
        """Conv1dNorm with stride 2 halves the time axis and outputs zeros past the given lengths."""
        stride = 2
        out_len = self.max_len // stride
        out_len_1, out_len_2 = self.len1 // stride, self.len2 // stride

        batch = torch.rand([self.batch_size, self.in_channels, self.max_len])
        # The lengths handed to the layer are the down-sampled ones checked below.
        seq_lens = torch.tensor([out_len_1, out_len_2], dtype=torch.int32)

        layer = Conv1dNorm(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=stride,
            padding=get_down_sample_padding(kernel_size=self.kernel_size, stride=stride),
        )
        result = layer(inputs=batch, input_len=seq_lens)

        assert result.shape == (self.batch_size, self.out_channels, out_len)
        self._assert_masked(result, (out_len_1, out_len_2))

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_conv1d_transpose_upsample(self):
        """ConvTranspose1dNorm with stride 2 doubles the time axis and outputs zeros past the given lengths."""
        stride = 2
        out_len = self.max_len * stride
        out_len_1, out_len_2 = self.len1 * stride, self.len2 * stride

        batch = torch.rand([self.batch_size, self.in_channels, self.max_len])
        # The lengths handed to the layer are the up-sampled ones checked below.
        seq_lens = torch.tensor([out_len_1, out_len_2], dtype=torch.int32)

        layer = ConvTranspose1dNorm(
            in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=stride
        )
        result = layer(inputs=batch, input_len=seq_lens)

        assert result.shape == (self.batch_size, self.out_channels, out_len)
        self._assert_masked(result, (out_len_1, out_len_2))

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_residual_block(self):
        """ResidualBlock preserves the input shape and keeps padded positions at zero."""
        seq_lens = torch.tensor([self.len1, self.len2], dtype=torch.int32)
        batch = mask_sequence_tensor(
            tensor=torch.rand([self.batch_size, self.in_channels, self.max_len]), lengths=seq_lens
        )

        block = ResidualBlock(channels=self.in_channels, filters=self.filters)
        result = block(inputs=batch, input_len=seq_lens)

        assert result.shape == (self.batch_size, self.in_channels, self.max_len)
        self._assert_masked(result, (self.len1, self.len2))

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_hifigan_decoder(self):
        """HiFiGANDecoder scales lengths by the total up-sample factor and keeps padding at zero."""
        up_sample_rates = [4, 4, 2, 2]
        up_sample_total = 64  # 4 * 4 * 2 * 2
        seq_lens = torch.tensor([self.len1, self.len2], dtype=torch.int32)
        expected_len_1 = self.len1 * up_sample_total
        expected_len_2 = self.len2 * up_sample_total
        expected_len_max = self.max_len * up_sample_total

        batch = mask_sequence_tensor(
            tensor=torch.rand([self.batch_size, self.in_channels, self.max_len]), lengths=seq_lens
        )

        decoder = HiFiGANDecoder(
            input_dim=self.in_channels, base_channels=self.filters, up_sample_rates=up_sample_rates
        )
        audio, audio_len = decoder(inputs=batch, input_len=seq_lens)

        assert audio_len[0] == expected_len_1
        assert audio_len[1] == expected_len_2
        # The decoder output checked here is two-dimensional: (batch, time).
        assert audio.shape == (self.batch_size, expected_len_max)
        self._assert_masked(audio, (expected_len_1, expected_len_2))

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_resnet_encoder(self):
        """ResNetEncoder maps in_channels to out_channels and keeps padded positions at zero."""
        seq_lens = torch.tensor([self.len1, self.len2], dtype=torch.int32)
        batch = mask_sequence_tensor(
            tensor=torch.rand([self.batch_size, self.in_channels, self.max_len]), lengths=seq_lens
        )

        encoder = ResNetEncoder(in_channels=self.in_channels, out_channels=self.out_channels)
        result = encoder(inputs=batch, input_len=seq_lens)

        assert result.shape == (self.batch_size, self.out_channels, self.max_len)
        self._assert_masked(result, (self.len1, self.len2))

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_multiband_mel_encoder(self):
        """MultiBandMelEncoder emits out_channels per mel band with the expected frame counts and masking."""
        mel_dim = 10
        win_length = 16
        hop_length = 10
        mel_bands = [(0, 4), (4, 7), (7, 10)]
        # This test uses its own, longer audio lengths instead of the class-level ones.
        max_len = 100
        len1 = 40
        len2 = 80

        out_dim = len(mel_bands) * self.out_channels
        audio_lens = torch.tensor([len1, len2], dtype=torch.int32)
        expected_len_1 = len1 // hop_length - 1
        expected_len_2 = len2 // hop_length - 1
        expected_len_max = max_len // hop_length

        audio = mask_sequence_tensor(tensor=torch.rand([self.batch_size, max_len]), lengths=audio_lens)

        mel_processor = MelSpectrogramProcessor(
            mel_dim=mel_dim, sample_rate=100, win_length=win_length, hop_length=hop_length
        )
        encoder = MultiBandMelEncoder(
            mel_bands=mel_bands,
            mel_processor=mel_processor,
            out_channels=self.out_channels,
            hidden_channels=self.filters,
            filters=self.filters,
        )
        encoded, encoded_len = encoder(audio=audio, audio_len=audio_lens)

        assert encoded_len[0] == expected_len_1
        assert encoded_len[1] == expected_len_2
        assert encoded.shape == (self.batch_size, out_dim, expected_len_max)
        self._assert_masked(encoded, (expected_len_1, expected_len_2))
class TestResidualVectorQuantizer:
    """Consistency checks for the (group) residual vector quantizers in eval mode."""

    def setup_class(self):
        """Define the sizes shared by all tests in this class."""
        self.num_examples = 10
        self.batch_size = 2
        self.max_len = 20
        self.codebook_dim = 64
        self.codebook_size = 256

    @pytest.mark.unit
    @pytest.mark.parametrize('num_codebooks', [1, 4])
    def test_rvq_eval(self, num_codebooks: int):
        """Check that RVQ can be built and run, and that forward agrees with encode followed by decode."""
        quantizer = ResidualVectorQuantizer(num_codebooks=num_codebooks, codebook_dim=self.codebook_dim)
        quantizer.eval()

        for example in range(self.num_examples):
            batch = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
            # Every sequence in the batch uses the full length.
            batch_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # Forward pass; in eval mode the commit loss is expected to be zero.
            dequantized_fw, indices_fw, commit_loss = quantizer(inputs=batch, input_len=batch_len)
            assert commit_loss == 0.0, f'example {example}: commit_loss is {commit_loss}, expected 0.0'

            # Explicit encode -> decode round trip.
            indices_enc = quantizer.encode(inputs=batch, input_len=batch_len)
            dequantized_dec = quantizer.decode(indices=indices_enc, input_len=batch_len)

            # Both paths must agree.
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {example}: indices mismatch')
            torch.testing.assert_close(
                dequantized_dec, dequantized_fw, msg=f'example {example}: dequantized mismatch'
            )

    @pytest.mark.unit
    @pytest.mark.parametrize('num_groups', [1, 2, 4])
    @pytest.mark.parametrize('num_codebooks', [1, 4])
    def test_group_rvq_eval(self, num_groups: int, num_codebooks: int):
        """Check that group RVQ can be built and run, and that forward agrees with encode followed by decode.

        When there are more groups than codebooks, construction is expected to raise ValueError.
        Otherwise the grouped result is also compared against each per-group RVQ run on its own.
        """
        if num_groups > num_codebooks:
            # Construction is expected to fail when num_groups is larger than the number of codebooks.
            with pytest.raises(ValueError):
                _ = GroupResidualVectorQuantizer(
                    num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
                )
            return

        # Valid configuration: instantiate and switch to eval mode.
        quantizer = GroupResidualVectorQuantizer(
            num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
        )
        quantizer.eval()

        for example in range(self.num_examples):
            batch = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
            batch_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # Forward pass; in eval mode the commit loss is expected to be zero.
            dequantized_fw, indices_fw, commit_loss = quantizer(inputs=batch, input_len=batch_len)
            assert commit_loss == 0.0, f'example {example}: commit_loss is {commit_loss}, expected 0.0'

            # Explicit encode -> decode round trip must match the forward pass.
            indices_enc = quantizer.encode(inputs=batch, input_len=batch_len)
            dequantized_dec = quantizer.decode(indices=indices_enc, input_len=batch_len)
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {example}: indices mismatch')
            torch.testing.assert_close(
                dequantized_dec, dequantized_fw, msg=f'example {example}: dequantized mismatch'
            )

            # Split inputs and dequantized output along dim 1 and indices along dim 0, one chunk per
            # group, then compare each chunk with the corresponding stand-alone RVQ.
            batch_per_group = batch.chunk(num_groups, dim=1)
            dequantized_per_group = dequantized_fw.chunk(num_groups, dim=1)
            indices_per_group = indices_fw.chunk(num_groups, dim=0)

            for group in range(num_groups):
                group_rvq = quantizer.rvqs[group]
                group_dequantized, group_indices, _ = group_rvq(inputs=batch_per_group[group], input_len=batch_len)
                torch.testing.assert_close(
                    group_dequantized,
                    dequantized_per_group[group],
                    msg=f'example {example}: dequantized mismatch for group {group}',
                )
                torch.testing.assert_close(
                    group_indices,
                    indices_per_group[group],
                    msg=f'example {example}: indices mismatch for group {group}',
                )
class TestCodecActivation:
    """Smoke test for the codec activation wrapper."""

    def setup_class(self):
        """Define the input tensor dimensions used by the tests in this class."""
        self.batch_size = 2
        self.in_channels = 4
        self.max_len = 4

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_snake(self):
        """Run the 'snake' activation and check that the output shape equals the input shape."""
        expected_shape = (self.batch_size, self.in_channels, self.max_len)
        batch = torch.rand([self.batch_size, self.in_channels, self.max_len])

        activation = CodecActivation('snake', channels=self.in_channels)
        result = activation(x=batch)

        assert result.shape == expected_shape
class TestFiniteScalarQuantizer:
    """Tests for the finite scalar quantizer (FSQ) and its grouped variant."""

    def setup_class(self):
        """Setup common members"""
        self.batch_size = 2
        self.max_len = 20
        self.num_examples = 10

    @pytest.mark.unit
    @pytest.mark.parametrize('num_levels', [[2, 3], [8, 5, 5]])
    def test_fsq_eval(self, num_levels: list):
        """Simple test to confirm that the FSQ module can be instantiated and run,
        and that forward produces the same result as encode-decode.

        Also checks that the dequantized forward output stays within [-1, 1].

        Args:
            num_levels: list of level counts passed to FiniteScalarQuantizer.
        """
        fsq = FiniteScalarQuantizer(num_levels=num_levels)

        for i in range(self.num_examples):
            inputs = torch.randn([self.batch_size, fsq.codebook_dim, self.max_len])
            input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # apply forward
            dequantized_fw, indices_fw = fsq(inputs=inputs, input_len=input_len)

            assert dequantized_fw.max() <= 1.0, f'example {i}: dequantized_fw.max() is {dequantized_fw.max()}'
            assert dequantized_fw.min() >= -1.0, f'example {i}: dequantized_fw.min() is {dequantized_fw.min()}'

            # encode-decode
            indices_enc = fsq.encode(inputs=inputs, input_len=input_len)
            dequantized_dec = fsq.decode(indices=indices_enc, input_len=input_len)

            # make sure the results are the same
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
            torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

    @pytest.mark.unit
    def test_fsq_output(self):
        """Simple test to make sure the output of FSQ is correct for a single setup.

        The inputs below have shape (5, 2, 8); the expected dequantized tensor has the same
        shape and the expected indices have shape (1, 5, 8).

        To re-generate test vectors:
        ```
        num_examples, max_len = 5, 8
        inputs = torch.randn([num_examples, fsq.codebook_dim, max_len])
        input_len = torch.tensor([max_len] * num_examples, dtype=torch.int32)
        dequantized, indices = fsq(inputs=inputs, input_len=input_len)
        ```
        """
        num_levels = [3, 4]
        fsq = FiniteScalarQuantizer(num_levels=num_levels)

        # inputs
        inputs = torch.tensor(
            [
                [
                    [0.1483, -0.3855, -0.3715, -0.5913, -0.2212, -0.4226, -0.4864, -1.6069],
                    [-0.5519, -0.5307, -0.5995, -1.9675, -0.4439, 0.3938, -0.5636, -0.3655],
                ],
                [
                    [0.5184, 1.4028, 0.1553, -0.2324, 1.0363, -0.4981, -0.1203, -1.0335],
                    [-0.1567, -0.2274, 0.0424, -0.0819, -0.2122, -2.1851, -1.5035, -1.2237],
                ],
                [
                    [0.9497, 0.8510, -1.2021, 0.3299, -0.2388, 0.8445, 2.2129, -2.3383],
                    [1.5331, 0.0399, -0.7676, -0.4715, -0.5713, 0.8761, -0.9755, -0.7479],
                ],
                [
                    [1.7243, -1.2146, -0.1969, 1.9261, 0.1109, 0.4028, 0.1240, -0.0994],
                    [-0.3304, 2.1239, 0.1004, -1.4060, 1.1463, -0.0557, -0.5856, -1.2441],
                ],
                [
                    [2.3743, -0.1421, -0.4548, 0.6320, -0.2640, -0.3967, -2.5694, 0.0493],
                    [0.3409, 0.2366, -0.0309, -0.7652, 0.3484, -0.8419, 0.9079, -0.9929],
                ],
            ]
        )
        input_len = torch.tensor([8, 8, 8, 8, 8], dtype=torch.int32)

        # expected output
        dequantized_expected = torch.tensor(
            [
                [
                    [0.0000, 0.0000, 0.0000, -1.0000, 0.0000, 0.0000, 0.0000, -1.0000],
                    [-0.5000, -0.5000, -0.5000, -1.0000, -0.5000, 0.0000, -0.5000, -0.5000],
                ],
                [
                    [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.0000, -1.0000, -1.0000],
                ],
                [
                    [1.0000, 1.0000, -1.0000, 0.0000, 0.0000, 1.0000, 1.0000, -1.0000],
                    [0.5000, 0.0000, -0.5000, -0.5000, -0.5000, 0.5000, -0.5000, -0.5000],
                ],
                [
                    [1.0000, -1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.5000, 0.0000, -1.0000, 0.5000, 0.0000, -0.5000, -1.0000],
                ],
                [
                    [1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, -0.5000, 0.0000, -0.5000, 0.5000, -0.5000],
                ],
            ]
        )

        indices_expected = torch.tensor(
            [
                [
                    [4, 4, 4, 0, 4, 7, 4, 3],
                    [7, 8, 7, 7, 8, 1, 1, 0],
                    [11, 8, 3, 4, 4, 11, 5, 3],
                    [8, 9, 7, 2, 10, 7, 4, 1],
                    [8, 7, 7, 5, 7, 4, 9, 4],
                ]
            ],
            dtype=torch.int32,
        )

        # test
        dequantized, indices = fsq(inputs=inputs, input_len=input_len)

        # Plain string literals: these messages have no placeholders, so no f-prefix is needed.
        torch.testing.assert_close(dequantized, dequantized_expected, msg='dequantized mismatch')
        torch.testing.assert_close(indices, indices_expected, msg='indices mismatch')

    @pytest.mark.unit
    @pytest.mark.parametrize('num_groups', [1, 2, 4])
    @pytest.mark.parametrize('num_levels_per_group', [[2, 3], [8, 5, 5]])
    def test_group_fsq_eval(self, num_groups: int, num_levels_per_group: list):
        """Simple test to confirm that the group FSQ module can be instantiated and run,
        and that forward produces the same result as encode-decode.

        The grouped result is also compared against each per-group FSQ run on its own chunk.

        Args:
            num_groups: number of groups passed to GroupFiniteScalarQuantizer.
            num_levels_per_group: list of level counts passed to GroupFiniteScalarQuantizer.
        """
        # Test inference with group FSQ
        # instantiate
        gfsq = GroupFiniteScalarQuantizer(num_groups=num_groups, num_levels_per_group=num_levels_per_group)

        for i in range(self.num_examples):
            inputs = torch.randn([self.batch_size, gfsq.codebook_dim, self.max_len])
            input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

            # apply forward
            dequantized_fw, indices_fw = gfsq(inputs=inputs, input_len=input_len)

            # encode-decode
            indices_enc = gfsq.encode(inputs=inputs, input_len=input_len)
            dequantized_dec = gfsq.decode(indices=indices_enc, input_len=input_len)

            # make sure the results are the same
            torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
            torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

            # apply individual FSQs and make sure the results are the same
            # (inputs and dequantized output are chunked along dim 1, indices along dim 0)
            inputs_grouped = inputs.chunk(num_groups, dim=1)
            dequantized_fw_grouped = dequantized_fw.chunk(num_groups, dim=1)
            indices_fw_grouped = indices_fw.chunk(num_groups, dim=0)

            for g in range(num_groups):
                dequantized, indices = gfsq.fsqs[g](inputs=inputs_grouped[g], input_len=input_len)
                torch.testing.assert_close(
                    dequantized, dequantized_fw_grouped[g], msg=f'example {i}: dequantized mismatch for group {g}'
                )
                torch.testing.assert_close(
                    indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}'
                )