| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import tempfile |
| |
|
| | import pytest |
| | import torch |
| | from omegaconf import DictConfig |
| |
|
| | from nemo.collections.tts.models import WaveGlowModel |
| | from nemo.core.classes import typecheck |
| |
|
# Hyper-parameters handed to the WaveGlow module named by ``_target_``.
mcfg = DictConfig(
    {
        "_target_": "nemo.collections.tts.modules.waveglow.WaveGlowModule",
        "n_flows": 12,
        "n_group": 8,
        "n_mel_channels": 80,
        "n_early_every": 4,
        "n_early_size": 2,
        "n_wn_channels": 512,
        "n_wn_layers": 8,
        "wn_kernel_size": 3,
    }
)

# Settings for the audio preprocessor named by ``_target_``; ``nfilt`` equals
# ``n_mel_channels`` in ``mcfg``.
pcfg = DictConfig(
    {
        "_target_": "nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures",
        "dither": 0.0,
        "nfilt": 80,
        "stft_conv": False,
    }
)

# Top-level model configuration passed to ``WaveGlowModel`` by the test below.
wcfg = DictConfig(
    {
        "waveglow": mcfg,
        "sigma": 1.0,
        "preprocessor": pcfg,
    }
)
| |
|
| |
|
def input_example(sz, *, device="cuda", dtype=torch.half):
    """Build a random ``(mel, z)`` input pair for the WaveGlow export test.

    Args:
        sz: Size of the last axis of ``mel``. ``z`` receives ``sz * 256 // 8``
            entries along its third axis (256 is presumably the hop length and
            8 the ``n_group`` value of the module config -- TODO confirm).
        device: Device the tensors are moved to. Defaults to ``"cuda"``, which
            reproduces the original ``.cuda()`` call.
        dtype: dtype the tensors are cast to. Defaults to ``torch.half``, which
            reproduces the original ``.half()`` call.

    Returns:
        Tuple ``(mel, z)`` of shapes ``(1, 1, 80, sz)`` and
        ``(1, 8, sz * 256 // 8, 1)``.
    """
    # Sample first and convert afterwards, exactly like the original
    # ``randn(...).cuda().half()`` chain, so seeded runs produce the same values.
    mel = torch.randn(1, 1, 80, sz).to(device=device, dtype=dtype)
    z = torch.randn(1, 8, sz * 256 // 8, 1).to(device=device, dtype=dtype)
    return mel, z
| |
|
| |
|
def taco2wg(spec, z):
    """Flatten 4-D ``spec`` and ``z`` tensors into 3-D tensors.

    Args:
        spec: 4-D tensor. Its axes are reordered with ``permute(0, 3, 2, 1)``
            and the last two resulting axes are merged into one.
        z: 4-D tensor whose last two axes are merged into one, or ``None``.

    Returns:
        Tuple ``(spec, z)`` of 3-D tensors. ``z`` is returned as ``None`` when
        ``None`` was passed in.
    """
    # ``view`` needs contiguous memory after the permutation.
    spec = spec.permute(0, 3, 2, 1).contiguous()
    spec = spec.view(spec.size(0), spec.size(1), -1)
    # Callers may legitimately omit ``z`` (it defaults to ``None`` upstream);
    # pass that through instead of failing on ``None.view``.
    if z is not None:
        z = z.view(z.size(0), z.size(1), -1)
    return spec, z
| |
|
| |
|
| | |
def forward_wrapper(self, spec, z=None):
    """Export-time forward pass: reshape the inputs, then synthesize audio.

    Written as a free function taking ``self`` so it can be assigned onto a
    model class as a method.

    Args:
        self: Object exposing a ``waveglow`` attribute with a
            ``norm_dist_to_audio`` method.
        spec: Spectrogram tensor in the layout accepted by ``taco2wg``.
        z: Tensor in the layout accepted by ``taco2wg``.

    Returns:
        Whatever ``self.waveglow.norm_dist_to_audio`` returns for the reshaped
        inputs with ``sigma=1.0``.
    """
    flat_spec, flat_z = taco2wg(spec, z)
    return self.waveglow.norm_dist_to_audio(spec=flat_spec, sigma=1.0, z=flat_z)
| |
|
| |
|
class TestWaveGlow:
    """ONNX export smoke test for ``WaveGlowModel``."""

    @pytest.mark.pleasefixme
    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    def test_export_to_onnx(self):
        """Export a half-precision WaveGlow model to ONNX inside a temp directory.

        Type checking is switched off and ``WaveGlowModel.forward_for_export`` is
        replaced with ``forward_wrapper`` for the duration of the test; both are
        put back afterwards so the changes do not leak into other tests.
        """
        model = WaveGlowModel(wcfg)
        model = model.cuda().half()

        # Remember what the class itself defines (if anything) so the monkeypatch
        # below can be undone precisely.
        not_defined = object()
        saved_forward = WaveGlowModel.__dict__.get("forward_for_export", not_defined)

        typecheck.set_typecheck_enabled(enabled=False)
        try:
            with tempfile.TemporaryDirectory() as tmpdir, model.nemo_infer():
                tmp_file_name = os.path.join(tmpdir, "waveglow.onnx")

                n_mels = 80
                inp = input_example(n_mels)
                inp1 = taco2wg(*inp)
                inp2 = inp1
                # Same inputs twice: checks that repeated forward passes agree.
                res1 = model.waveglow(*inp1)
                res2 = model.waveglow(*inp2)
                assert torch.allclose(res1, res2, rtol=0.01, atol=0.1)

                WaveGlowModel.forward_for_export = forward_wrapper
                model.export(
                    tmp_file_name, input_example=inp, verbose=False, check_trace=False, do_constant_folding=True,
                )
        finally:
            # Undo the class-level monkeypatch.
            if saved_forward is not_defined:
                if "forward_for_export" in WaveGlowModel.__dict__:
                    del WaveGlowModel.forward_for_export
            else:
                WaveGlowModel.forward_for_export = saved_forward
            # Assumes type checking was enabled before the test ran.
            typecheck.set_typecheck_enabled(enabled=True)
| |
|
| |
|
if __name__ == "__main__":
    # Allow running the export test directly, without going through pytest.
    TestWaveGlow().test_export_to_onnx()
| |
|