|
|
import torch
|
|
|
from models.fatchord_version import WaveRNN
|
|
|
from utils import hparams as hp
|
|
|
from utils.text.symbols import symbols
|
|
|
from models.tacotron import Tacotron
|
|
|
import argparse
|
|
|
from utils.text import text_to_sequence
|
|
|
from utils.display import save_attention, simple_table
|
|
|
import zipfile, os
|
|
|
|
|
|
|
|
|
# Unpack the bundled pretrained checkpoints into the folders the models load
# their weights from further down in this script.
os.makedirs('quick_start/tts_weights/', exist_ok=True)
os.makedirs('quick_start/voc_weights/', exist_ok=True)

# `with` guarantees each archive handle is closed even if extraction raises.
with zipfile.ZipFile('pretrained/ljspeech.wavernn.mol.800k.zip', 'r') as zip_ref:
    zip_ref.extractall('quick_start/voc_weights/')

with zipfile.ZipFile('pretrained/ljspeech.tacotron.r2.180k.zip', 'r') as zip_ref:
    zip_ref.extractall('quick_start/tts_weights/')
|
|
|
|
|
|
|
|
|
def build_arg_parser():
    """Build the command-line parser for the quick-start TTS script.

    ``--batched`` and ``--unbatched`` both write to the ``batched``
    destination. The defaults are registered on the parser before it is
    used: ``set_defaults`` only influences later ``parse_args`` calls, so
    calling it after parsing leaves the already-built namespace untouched.

    Returns:
        argparse.ArgumentParser: parser producing ``input_text``, ``batched``,
        ``force_cpu`` and ``hp_file`` attributes.
    """
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
    parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation (lower quality)')
    parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slower Unbatched Generation (better quality)')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
                        help='The file to use for the hyperparameters')

    # Must happen before parse_args(); otherwise `batched` silently falls back
    # to the default of the first registered action (store_true -> False).
    parser.set_defaults(batched=True)
    parser.set_defaults(input_text=None)
    return parser


if __name__ == "__main__":

    args = build_arg_parser().parse_args()

    # Load the hyperparameters from the requested file before any model is built.
    hp.configure(args.hp_file)

    batched = args.batched
    input_text = args.input_text

    # Values handed to the vocoder's generate() call and echoed in the summary table.
    target_samples = 11_000
    overlap_samples = 550

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising WaveRNN Model...\n')

    # Vocoder, configured entirely from hparams.
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode='MOL').to(device)

    voc_model.load('quick_start/voc_weights/latest_weights.pyt')

    print('\nInitialising Tacotron Model...\n')

    # Text-to-mel model, configured entirely from hparams.
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(symbols),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)

    tts_model.load('quick_start/tts_weights/latest_weights.pyt')

    # Either synthesise the single sentence given on the command line, or
    # every line of sentences.txt.
    if input_text:
        inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
    else:
        with open('sentences.txt') as f:
            inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]

    # Checkpoint step counts, in thousands, for the summary and file names.
    voc_k = voc_model.get_step() // 1000
    tts_k = tts_model.get_step() // 1000

    r = tts_model.r

    simple_table([('WaveRNN', str(voc_k) + 'k'),
                  (f'Tacotron(r={r})', str(tts_k) + 'k'),
                  ('Generation Mode', 'Batched' if batched else 'Unbatched'),
                  ('Target Samples', target_samples if batched else 'N/A'),
                  ('Overlap Samples', overlap_samples if batched else 'N/A')])

    for i, x in enumerate(inputs, 1):

        print(f'\n| Generating {i}/{len(inputs)}')
        _, m, attention = tts_model.generate(x)

        # NOTE(review): the first 10 characters of the input text go into the
        # file name verbatim; text containing path separators would alter the
        # target path.
        if input_text:
            save_path = f'quick_start/__input_{input_text[:10]}_{tts_k}k.wav'
        else:
            save_path = f'quick_start/{i}_batched{str(batched)}_{tts_k}k.wav'

        # Add a leading batch axis, then shift/scale the Tacotron output before
        # vocoding (presumably maps [-4, 4] onto [0, 1] - TODO confirm).
        m = torch.tensor(m).unsqueeze(0)
        m = (m + 4) / 8

        voc_model.generate(m, save_path, batched, target_samples, overlap_samples, hp.mu_law)

    print('\n\nDone.\n')
|
|
|
|