| |
| |
| |
| |
| |
|
|
| import random |
|
|
| import torch |
|
|
| from audiocraft.adversarial.discriminators import ( |
| MultiPeriodDiscriminator, |
| MultiScaleDiscriminator, |
| MultiScaleSTFTDiscriminator |
| ) |
|
|
|
|
class TestMultiPeriodDiscriminator:
    """Output-structure checks for ``MultiPeriodDiscriminator``."""

    def test_mpd_discriminator(self):
        """Run the MPD on a random-length batch and check output counts and shapes.

        One logit tensor and one feature-map list are expected per period.
        Every logit is expected to be 4-D, and every tensor keeps the batch
        dimension N.
        """
        periods = [1, 2, 3]
        # Very short inputs (e.g. T=1) may be shorter than the padding/kernel
        # sizes the period discriminators need, which would make this test fail
        # randomly. NOTE(review): the exact minimum depends on audiocraft
        # internals; 1_024 is a deliberately conservative floor -- confirm.
        N, C, T = 2, 2, random.randrange(1_024, 100_000)
        t0 = torch.randn(N, C, T)
        mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C)
        logits, fmaps = mpd(t0)

        # T is random, so report it on failure to make the case reproducible.
        assert len(logits) == len(periods), f"T={T}"
        assert len(fmaps) == len(periods), f"T={T}"
        assert all(logit.shape[0] == N and logit.dim() == 4 for logit in logits), f"T={T}"
        assert all(feature.shape[0] == N for fmap in fmaps for feature in fmap), f"T={T}"
|
|
|
|
class TestMultiScaleDiscriminator:
    """Output-structure checks for ``MultiScaleDiscriminator``."""

    def test_msd_discriminator(self):
        """Run the MSD on a random-length batch and check output counts and shapes.

        One logit tensor and one feature-map list are expected per entry in
        ``scale_norms``. Every logit is expected to be 3-D, and every tensor
        keeps the batch dimension N.
        """
        # Very short inputs may be too short for the per-scale padding/kernels
        # once downsampled, which would make this test fail randomly.
        # NOTE(review): the exact minimum depends on audiocraft internals;
        # 1_024 is a deliberately conservative floor -- confirm.
        N, C, T = 2, 2, random.randrange(1_024, 100_000)
        t0 = torch.randn(N, C, T)

        scale_norms = ['weight_norm', 'weight_norm']
        msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C)
        logits, fmaps = msd(t0)

        # T is random, so report it on failure to make the case reproducible.
        assert len(logits) == len(scale_norms), f"T={T}"
        assert len(fmaps) == len(scale_norms), f"T={T}"
        assert all(logit.shape[0] == N and logit.dim() == 3 for logit in logits), f"T={T}"
        assert all(feature.shape[0] == N for fmap in fmaps for feature in fmap), f"T={T}"
|
|
|
|
class TestMultiScaleStftDiscriminator:
    """Output-structure checks for ``MultiScaleSTFTDiscriminator``."""

    def test_msstftd_discriminator(self):
        """Run the MS-STFT discriminator on a random-length batch and check outputs.

        One logit tensor and one feature-map list are expected per STFT
        resolution in ``n_ffts``. Every logit is expected to be 4-D, and every
        tensor keeps the batch dimension N.
        """
        n_filters = 4
        n_ffts = [128, 256, 64]
        hop_lengths = [32, 64, 16]
        win_lengths = [128, 256, 64]

        # The signal must be at least as long as the largest FFT window for
        # every resolution to yield a frame. NOTE(review): this assumes the
        # STFT is not centre-padded; the 1_024 floor leaves extra margin either way.
        min_length = max(1_024, max(n_ffts))
        N, C, T = 2, 2, random.randrange(min_length, 100_000)
        t0 = torch.randn(N, C, T)

        msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths,
                                              win_lengths=win_lengths, in_channels=C)
        logits, fmaps = msstftd(t0)

        # T is random, so report it on failure to make the case reproducible.
        assert len(logits) == len(n_ffts), f"T={T}"
        assert len(fmaps) == len(n_ffts), f"T={T}"
        assert all(logit.shape[0] == N and logit.dim() == 4 for logit in logits), f"T={T}"
        assert all(feature.shape[0] == N for fmap in fmaps for feature in fmap), f"T={T}"
|
|