| import numpy as np |
| import torch |
|
|
| from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator |
| from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator |
|
|
|
|
| def test_melgan_discriminator(): |
| model = MelganDiscriminator() |
| print(model) |
| dummy_input = torch.rand((4, 1, 256 * 10)) |
| output, _ = model(dummy_input) |
| assert np.all(output.shape == (4, 1, 10)) |
|
|
|
|
| def test_melgan_multi_scale_discriminator(): |
| model = MelganMultiscaleDiscriminator() |
| print(model) |
| dummy_input = torch.rand((4, 1, 256 * 16)) |
| scores, feats = model(dummy_input) |
| assert len(scores) == 3 |
| assert len(scores) == len(feats) |
| assert np.all(scores[0].shape == (4, 1, 64)) |
| assert np.all(feats[0][0].shape == (4, 16, 4096)) |
| assert np.all(feats[0][1].shape == (4, 64, 1024)) |
| assert np.all(feats[0][2].shape == (4, 256, 256)) |
|
|