| | from typing import Dict |
| |
|
| | import numpy as np |
| | import torch |
| | from matplotlib import pyplot as plt |
| |
|
| | from TTS.tts.utils.visual import plot_spectrogram |
| | from TTS.utils.audio import AudioProcessor |
| |
|
| |
|
def interpolate_vocoder_input(scale_factor, spec):
    """Interpolate spectrogram by the scale factor.

    It is mainly used to match the sampling rates of the tts and vocoder models.
    The spectrogram is resized as a single-channel 2D image with bilinear
    interpolation.

    Args:
        scale_factor (float or sequence of two floats): scale factor for both
            spectrogram axes, or one factor per axis.
        spec (np.array or torch.Tensor): 2D spectrogram to be interpolated.

    Returns:
        torch.tensor: interpolated spectrogram with a leading dimension of size 1,
        i.e. shape ``(1, H', W')`` for an input of shape ``(H, W)``.
    """
    print(" > before interpolation :", spec.shape)
    if isinstance(spec, torch.Tensor):
        # torch.tensor() on an existing tensor copies it and emits a UserWarning;
        # detach() yields the same gradient-free input without the warning.
        # interpolate() allocates a new output, so the caller's tensor is not modified.
        spec = spec.detach()
    else:
        spec = torch.tensor(spec)
    # (H, W) -> (1, 1, H, W): bilinear interpolation requires a 4D (N, C, H, W) input
    spec = spec.unsqueeze(0).unsqueeze(0)
    spec = torch.nn.functional.interpolate(
        spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
    ).squeeze(0)
    print(" > after interpolation :", spec.shape)
    return spec
| |
|
| |
|
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
    """Plot the predicted and the real waveform and their spectrograms.

    Only the first item (index ``0``) of each batch is plotted.

    Args:
        y_hat (torch.tensor): Predicted waveform batch.
        y (torch.tensor): Real waveform batch.
        ap (AudioProcessor): Audio processor whose ``melspectrogram`` computes the spectrograms.
        name_prefix (str, optional): Prefix prepended to every figure name. Defaults to None (no prefix).

    Returns:
        Dict: output figures keyed by ``spectrogram/fake``, ``spectrogram/real``,
        ``spectrogram/diff`` and ``speech_comparison``, each prefixed with ``name_prefix``.
    """
    if name_prefix is None:
        name_prefix = ""

    # select the first instance from each batch and move it to a numpy array
    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
    y = y[0].squeeze().detach().cpu().numpy()

    spec_fake = ap.melspectrogram(y_hat).T
    spec_real = ap.melspectrogram(y).T
    # If the generated and real waveforms differ in length, the spectrograms differ in
    # shape and an element-wise difference cannot broadcast. Compare only the region
    # the two spectrograms share; when the shapes match this is the full spectrogram.
    overlap = tuple(slice(0, min(a, b)) for a, b in zip(spec_fake.shape, spec_real.shape))
    spec_diff = np.abs(spec_fake[overlap] - spec_real[overlap])

    # waveform comparison: ground truth on top, generated speech below
    fig_wave = plt.figure()
    ax_real = fig_wave.add_subplot(2, 1, 1)
    ax_real.plot(y)
    ax_real.set_title("groundtruth speech")
    ax_fake = fig_wave.add_subplot(2, 1, 2)
    ax_fake.plot(y_hat)
    ax_fake.set_title("generated speech")
    fig_wave.tight_layout()
    # deregister the figure from pyplot; the Figure object itself is still returned below
    plt.close(fig_wave)

    figures = {
        name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
        name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
        name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
        name_prefix + "speech_comparison": fig_wave,
    }
    return figures
| |
|