| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| |
|
| | import torch |
| |
|
| | from audiocraft.adversarial.discriminators import ( |
| | MultiPeriodDiscriminator, |
| | MultiScaleDiscriminator, |
| | MultiScaleSTFTDiscriminator |
| | ) |
| |
|
| |
|
class TestMultiPeriodDiscriminator:
    """Output-structure checks for ``MultiPeriodDiscriminator``."""

    def test_mpd_discriminator(self):
        """One logit tensor and one feature-map list per period, batch dim preserved."""
        # A seeded local RNG draws the length from the same range as before, but the
        # draw is reproducible across runs and does not touch the global ``random`` state.
        rng = random.Random(1234)
        N, C, T = 2, 2, rng.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)
        periods = [1, 2, 3]
        mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C)
        logits, fmaps = mpd(t0)

        assert len(logits) == len(periods), f"T={T}"
        assert len(fmaps) == len(periods), f"T={T}"
        # Each logit keeps the batch dimension and is 4-D.
        assert all(logit.shape[0] == N and logit.dim() == 4 for logit in logits), f"T={T}"
        # Every intermediate feature map keeps the batch dimension.
        assert all(feature.shape[0] == N for fmap in fmaps for feature in fmap), f"T={T}"
| |
|
| |
|
class TestMultiScaleDiscriminator:
    """Output-structure checks for ``MultiScaleDiscriminator``."""

    def test_msd_discriminator(self):
        """One logit tensor and one feature-map list per scale, batch dim preserved."""
        # A seeded local RNG draws the length from the same range as before, but the
        # draw is reproducible across runs and does not touch the global ``random`` state.
        rng = random.Random(1234)
        N, C, T = 2, 2, rng.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)

        scale_norms = ['weight_norm', 'weight_norm']
        msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C)
        logits, fmaps = msd(t0)

        assert len(logits) == len(scale_norms), f"T={T}"
        assert len(fmaps) == len(scale_norms), f"T={T}"
        # Each logit keeps the batch dimension and is 3-D.
        assert all(logit.shape[0] == N and logit.dim() == 3 for logit in logits), f"T={T}"
        # Every intermediate feature map keeps the batch dimension.
        assert all(feature.shape[0] == N for fmap in fmaps for feature in fmap), f"T={T}"
| |
|
| |
|
class TestMultiScaleStftDiscriminator:
    """Output-structure checks for ``MultiScaleSTFTDiscriminator``."""

    def test_msstftd_discriminator(self):
        """One logit tensor and one feature-map list per STFT config, batch dim preserved."""
        # A seeded local RNG draws the length from the same range as before, but the
        # draw is reproducible across runs and does not touch the global ``random`` state.
        rng = random.Random(1234)
        N, C, T = 2, 2, rng.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)

        n_filters = 4
        # The three lists are parallel: entry i of each describes the i-th sub-discriminator.
        n_ffts = [128, 256, 64]
        hop_lengths = [32, 64, 16]
        win_lengths = [128, 256, 64]

        msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths,
                                              win_lengths=win_lengths, in_channels=C)
        logits, fmaps = msstftd(t0)

        assert len(logits) == len(n_ffts), f"T={T}"
        assert len(fmaps) == len(n_ffts), f"T={T}"
        # Each logit keeps the batch dimension and is 4-D.
        assert all(logit.shape[0] == N and logit.dim() == 4 for logit in logits), f"T={T}"
        # Every intermediate feature map keeps the batch dimension.
        assert all(feature.shape[0] == N for fmap in fmaps for feature in fmap), f"T={T}"
| |
|