| | """Search a good noise schedule for WaveGrad for a given number of inference iterations""" |
| | import argparse |
| | from itertools import product as cartesian_product |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| |
|
| | from TTS.config import load_config |
| | from TTS.utils.audio import AudioProcessor |
| | from TTS.vocoder.datasets.preprocess import load_wav_data |
| | from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset |
| | from TTS.vocoder.models import setup_model |
| |
|
def main(argv=None):
    """Grid-search a WaveGrad noise schedule for a fixed number of inference iterations.

    Candidate schedules are built by scaling a log-spaced vector of exponents
    (1e-6 .. 1e-1, one entry per inference iteration) with every combination of
    ``--search_depth`` randomly drawn base values. Each candidate is scored by
    the squared error between the input mel-spectrograms and the
    mel-spectrograms of the audio generated with that schedule, summed over all
    selected samples. Whenever a candidate beats the best score so far it is
    written to ``--output_path`` as ``{"beta": <np.ndarray>}``.

    Args:
        argv: Command line arguments without the program name. ``None`` makes
            argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: On missing or invalid command line arguments (raised by argparse).
        RuntimeError: If no data samples are available for the search.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to model checkpoint.")
    parser.add_argument("--config_path", type=str, required=True, help="Path to model config file.")
    parser.add_argument("--data_path", type=str, required=True, help="Path to data directory.")
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="path for output file including file name and extension.",
    )
    parser.add_argument(
        "--num_iter",
        type=int,
        required=True,
        help="Number of model inference iterations that you like to optimize noise schedule for.",
    )
    parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
    parser.add_argument(
        "--search_depth",
        type=int,
        default=3,
        help="Search granularity. Increasing this increases the run-time exponentially.",
    )
    args = parser.parse_args(argv)

    # Reject values that would make the search silently degenerate
    # (empty schedule, empty search space or empty dataset).
    for option in ("num_iter", "num_samples", "search_depth"):
        if getattr(args, option) < 1:
            parser.error(f"--{option} must be a positive integer.")

    # load config
    config = load_config(args.config_path)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # load dataset; only the second list returned by `load_wav_data` is used
    _, train_data = load_wav_data(args.data_path, 0)
    train_data = train_data[: args.num_samples]
    if len(train_data) == 0:
        raise RuntimeError(f" [!] No data samples found in {args.data_path}")
    dataset = WaveGradDataset(
        ap=ap,
        items=train_data,
        seq_len=-1,
        hop_len=ap.hop_length,
        pad_short=config.pad_short,
        conv_pad=config.conv_pad,
        is_training=True,
        return_segments=False,
        use_noise_augment=False,
        use_cache=False,
        verbose=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collate_full_clips,
        drop_last=False,
        num_workers=config.num_loader_workers,
        pin_memory=False,
    )

    # setup the model and restore the trained weights; without this step the
    # search would score schedules against a freshly initialised network.
    model = setup_model(config)
    checkpoint = torch.load(args.model_path, map_location="cpu")
    # NOTE(review): assumes the checkpoint is either a bare state dict or a dict
    # holding the weights under the "model" key - confirm against the trainer.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint)
    model.eval()
    if args.use_cuda:
        model.cuda()

    # setup optimization parameters
    base_values = sorted(10 * np.random.uniform(size=args.search_depth))
    print(f" > base values: {base_values}")
    exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
    best_error = float("inf")
    best_schedule = None
    total_search_iter = len(base_values) ** args.num_iter
    for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
        beta = exponents * base
        model.compute_noise_level(beta)

        # Score the schedule on ALL selected samples before comparing, so a
        # single short or easy clip cannot decide the winner on its own.
        schedule_error = 0.0
        for mel, _ in loader:
            with torch.no_grad():
                y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
            y_hat = y_hat.cpu().numpy()

            # Re-analyse the generated audio; the last frame is dropped before
            # comparing with the input mel.
            mel_hat = torch.stack([torch.from_numpy(ap.melspectrogram(wav[0])[:, :-1]) for wav in y_hat])
            schedule_error += torch.sum((mel - mel_hat) ** 2).item()

        if schedule_error < best_error:
            best_error = schedule_error
            best_schedule = {"beta": beta}
            # The reported value is the summed squared error; the "MSE" label
            # is kept so the console output stays familiar.
            print(f" > Found a better schedule. - MSE: {best_error}")
            # The dict is stored as a pickled object array, so reading it back
            # requires `np.load(..., allow_pickle=True)`.
            np.save(args.output_path, best_schedule)


if __name__ == "__main__":
    main()
| |
|