| import os |
|
|
| import torch |
|
|
| from tests import get_tests_input_path, get_tests_output_path, get_tests_path |
| from TTS.config import BaseAudioConfig |
| from TTS.utils.audio import AudioProcessor |
| from TTS.utils.audio.numpy_transforms import stft |
| from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT |
|
|
# Value returned by the test package's path helper. Nothing in the visible
# part of this module reads it; it is kept as a module-level name.
TESTS_PATH = get_tests_path()

# Input waveform loaded by the STFT-based tests in this module.
WAV_FILE = os.path.join(
    get_tests_input_path(),
    "example_1.wav",
)

# Output directory for the audio tests. It is created as a side effect of
# importing this module; `exist_ok=True` makes repeated imports harmless.
OUT_PATH = os.path.join(
    get_tests_output_path(),
    "audio_tests",
)
os.makedirs(OUT_PATH, exist_ok=True)

# Shared audio processor built from the default `BaseAudioConfig` values.
# The tests read `fft_size`, `hop_length`, `win_length` and `load_wav` from it.
ap = AudioProcessor(**BaseAudioConfig().to_dict())
|
|
|
|
def test_torch_stft():
    """Compare `TorchSTFT` against the numpy `stft` transform on `WAV_FILE`.

    Both transforms are configured with the FFT size, hop length and window
    length of the shared audio processor `ap`. The absolute value of the
    numpy result is used as the reference and compared with the first batch
    item of the `TorchSTFT` output.
    """
    torch_transform = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)

    samples = ap.load_wav(WAV_FILE)
    reference = abs(
        stft(
            y=samples,
            fft_size=ap.fft_size,
            hop_length=ap.hop_length,
            win_length=ap.win_length,
        )
    )

    # Add a leading batch axis and cast to float32 before calling the module.
    batch = torch.from_numpy(samples).unsqueeze(0).float()
    torch_result = torch_transform(batch)[0].detach().numpy()

    # NOTE(review): this is the max of a *signed* difference, so it only
    # bounds how far the reference exceeds the torch output, not the reverse.
    # Confirm whether an absolute-difference check is intended.
    largest_gap = (reference - torch_result).max()
    assert largest_gap < 1e-5
|
|
|
|
def test_stft_loss():
    """Exercise `STFTLoss` on identical inputs and on signal-vs-noise inputs.

    The loss returns two terms. They must sum to exactly zero when both
    arguments are the same tensor. Against uniform random noise the second
    term must stay below 1.0 and the sum must be strictly positive.
    """
    criterion = STFTLoss(ap.fft_size, ap.hop_length, ap.win_length)

    # Batch of one: the example waveform as a float32 tensor.
    signal = torch.from_numpy(ap.load_wav(WAV_FILE)).unsqueeze(0).float()

    # Case 1: a signal compared with itself.
    first_term, second_term = criterion(signal, signal)
    assert first_term + second_term == 0

    # Case 2: the signal compared with random noise of the same shape.
    noise = torch.rand_like(signal)
    first_term, second_term = criterion(signal, noise)
    assert second_term < 1.0
    assert first_term + second_term > 0
|
|
|
|
def test_multiscale_stft_loss():
    """Exercise `MultiScaleSTFTLoss` at half, base and double resolution.

    The FFT size, hop length and window length of `ap` are each expanded
    into a three-element list (half, base, double) and passed positionally.
    The assertions mirror the single-scale test: the two returned terms sum
    to exactly zero for identical inputs; against random noise the second
    term is below 1.0 and the sum is strictly positive.
    """

    def around(base):
        # Half, unchanged and doubled variants of one STFT parameter.
        return [base // 2, base, base * 2]

    criterion = MultiScaleSTFTLoss(
        around(ap.fft_size),
        around(ap.hop_length),
        around(ap.win_length),
    )

    # Batch of one: the example waveform as a float32 tensor.
    signal = torch.from_numpy(ap.load_wav(WAV_FILE)).unsqueeze(0).float()

    # Case 1: a signal compared with itself.
    first_term, second_term = criterion(signal, signal)
    assert first_term + second_term == 0

    # Case 2: the signal compared with random noise of the same shape.
    noise = torch.rand_like(signal)
    first_term, second_term = criterion(signal, noise)
    assert second_term < 1.0
    assert first_term + second_term > 0
|
|
|
|
def test_melgan_feature_loss():
    """Check `MelganFeatureLoss` on independent and on identical feature sets.

    Each feature set is a list of 5 lists, each holding 4 random tensors of
    shape [3, 5, 7]. With independently drawn real/fake tensors the loss
    must not exceed 1.0; when real and fake hold the very same tensors the
    loss must be exactly zero.
    """
    outer_count = 5
    inner_count = 4
    feat_shape = [3, 5, 7]

    def build_feature_sets(shared):
        # Returns (real, fake) nested lists. When `shared` is true the same
        # tensor object is stored on both sides; otherwise real is drawn
        # before fake at every inner step.
        real_sets, fake_sets = [], []
        for _ in range(outer_count):
            pairs = []
            for _ in range(inner_count):
                real = torch.rand(feat_shape)
                fake = real if shared else torch.rand(feat_shape)
                pairs.append((real, fake))
            real_sets.append([real for real, _ in pairs])
            fake_sets.append([fake for _, fake in pairs])
        return real_sets, fake_sets

    # Case 1: independently sampled features.
    feats_real, feats_fake = build_feature_sets(shared=False)
    criterion = MelganFeatureLoss()
    loss = criterion(feats_fake, feats_real)
    assert loss.item() <= 1.0

    # Case 2: identical features on both sides.
    feats_real, feats_fake = build_feature_sets(shared=True)
    criterion = MelganFeatureLoss()
    loss = criterion(feats_fake, feats_real)
    assert loss.item() == 0
|
|