| import numpy as np |
| import os |
| import torch |
| import commons |
|
|
| import models |
| import utils |
| from argparse import ArgumentParser |
| from tqdm import tqdm |
| from text import text_to_sequence |
|
|
if __name__ == "__main__":
    # Script entry point: load the latest checkpoint found in --model_dir, run
    # the model in generation mode on every entry of the training and
    # validation filelists, and save each generated array as a .npy file in
    # --mels_dir.
    parser = ArgumentParser()
    parser.add_argument("-m", "--model_dir", required=True, type=str)
    parser.add_argument("-s", "--mels_dir", required=True, type=str)
    args = parser.parse_args()
    MODEL_DIR = args.model_dir
    SAVE_MELS_DIR = args.mels_dir

    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(SAVE_MELS_DIR, exist_ok=True)

    hps = utils.get_hparams_from_dir(MODEL_DIR)
    # Symbol inventory is punctuation followed by characters, both taken from
    # the model's own hyper-parameters.
    symbols = list(hps.data.punc) + list(hps.data.chars)
    checkpoint_path = utils.latest_checkpoint_path(MODEL_DIR)
    cleaner = hps.data.text_cleaners

    # When "add_blank" is truthy the vocabulary grows by one, making room for
    # the extra id len(symbols) that get_mel() passes to commons.intersperse.
    model = models.FlowGenerator(
        len(symbols) + getattr(hps.data, "add_blank", False),
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).to("cuda")

    utils.load_checkpoint(checkpoint_path, model)
    model.decoder.store_inverse()
    model.eval()

    def get_mel(text, fpath):
        """Run the model on ``text`` and save its first output to disk.

        Args:
            text: Transcript string, converted to symbol ids with
                ``text_to_sequence`` using the script-level ``symbols`` and
                ``cleaner``.
            fpath: File name, joined onto ``SAVE_MELS_DIR``, that the
                generated array is written to with ``np.save``.
        """
        if getattr(hps.data, "add_blank", False):
            text_norm = text_to_sequence(text, symbols, cleaner)
            # NOTE(review): presumably inserts the id len(symbols) between
            # the symbol ids -- confirm against commons.intersperse.
            text_norm = commons.intersperse(text_norm, len(symbols))
        else:
            # Without "add_blank", the stripped text is padded with exactly
            # one leading and one trailing space before conversion.
            text = " " + text.strip() + " "
            text_norm = text_to_sequence(text, symbols, cleaner)

        # Add a leading batch axis of size 1.
        sequence = np.array(text_norm)[None, :]

        # torch.autograd.Variable is a deprecated no-op wrapper, so the tensor
        # is used directly.
        x_tst = torch.from_numpy(sequence).cuda().long()
        x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda()

        with torch.no_grad():
            # Only the first element of the first returned group is needed;
            # everything else the model returns is discarded.
            (y_gen_tst, *_), *_ = model(
                x_tst,
                x_tst_lengths,
                gen=True,
                noise_scale=0.667,
                length_scale=1.0,
            )

        np.save(os.path.join(SAVE_MELS_DIR, fpath), y_gen_tst.cpu().detach().numpy())

    for filelist_path in [hps.data.training_files, hps.data.validation_files]:
        # Use a context manager so the filelist handle is closed promptly, and
        # drop blank lines, which would otherwise fail the two-field unpack.
        with open(filelist_path) as filelist:
            file_lines = [ln for ln in filelist.read().splitlines() if ln.strip()]

        for line in tqdm(file_lines):
            # Each entry must be exactly "<audio path>|<text>"; any other
            # field count raises ValueError here.
            fname, text = line.split("|")
            fname = os.path.basename(fname).replace(".wav", ".npy")
            get_mel(text, fname)
|
|