| import numpy as np |
| import torch |
| import random |
| from TTS.vocoder.models.wavernn import WaveRNN |
|
|
|
|
def test_wavernn():
    """Smoke-test WaveRNN output shapes for both forward and inference paths.

    Builds a small model, pushes a dummy training batch through ``forward``
    and a single dummy spectrogram through ``inference``, and asserts that
    the output tensor shapes match the configured upsampling arithmetic.
    """
    config = {
        "rnn_dims": 512,
        "fc_dims": 512,
        "mode": 10,
        "mulaw": False,
        "pad": 2,
        "use_aux_net": True,
        "use_upsample_net": True,
        "upsample_factors": [4, 8, 8],  # product (256) matches hop_length
        "feat_dims": 80,
        "compute_dims": 128,
        "res_out_dims": 128,
        "num_res_blocks": 10,
        "hop_length": 256,
        "sample_rate": 22050,
    }
    model = WaveRNN(**config)

    # Training-style forward: (batch, samples) waveform + (batch, mels, frames).
    wav_batch = torch.rand((2, 1280))
    mel_batch = torch.rand((2, 80, 9))
    out = model(wav_batch, mel_batch)
    # NOTE(review): 4 * 256 = 1024, presumably 2**mode output classes — confirm.
    assert np.all(out.shape == (2, 1280, 4 * 256)), out.shape

    # Inference on a single spectrogram of random length (batched=True, 5500/550
    # are the batching target/overlap arguments passed positionally upstream).
    mel_len = random.randrange(20, 60)
    single_mel = torch.rand((80, mel_len))
    out = model.inference(single_mel, True, 5500, 550)
    assert np.all(out.shape == (256 * (mel_len - 1),))
|
|