| |
| import argparse |
| import re |
| import os |
| import librosa |
| import numpy as np |
| import torch |
| from torch import no_grad, LongTensor |
| import gradio as gr |
| import ONNXVITS_infer |
| import utils |
| import commons |
| from text import text_to_sequence |
| from mel_processing import spectrogram_torch |
|
|
|
|
# Language tag wrapped around the input text so the text cleaner can route it
# to the right phonemizer (e.g. "[JA]text[JA]"). Empty string means no tag is
# added: presumably "Mix" input already carries inline [JA]/[ZH]/[EN] tags and
# "Japanese" is a legacy untagged mode — TODO confirm against text cleaners.
language_marks = {
    "Japanese": "",
    "日本語": "[JA]",
    "简体中文": "[ZH]",
    "English": "[EN]",
    "Mix": "",
}
|
|
def get_text(text, hps, is_symbol):
    """Convert input text (or a raw symbol string) into a 1-D LongTensor of token ids.

    When ``is_symbol`` is true the cleaner list is empty, so the text is
    treated as already-cleaned symbols. If the config requests blank tokens,
    a 0 is interspersed between every pair of ids.
    """
    cleaners = [] if is_symbol else hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return LongTensor(sequence)
|
|
def tts_infer(text, speaker, language, speed=1.0, is_symbol=False):
    """Synthesize speech for ``text`` and return ``(sampling_rate, waveform)``.

    Relies on module-level globals ``language_marks``, ``speaker_ids``,
    ``hps`` and ``model`` set up at import time.
    """
    if language is not None:
        # Wrap the text with the language tag (may be the empty string).
        mark = language_marks[language]
        text = f"{mark}{text}{mark}"
    sid_value = speaker_ids[speaker]
    tokens = get_text(text, hps, is_symbol)
    with no_grad():
        x = tokens.unsqueeze(0)  # add batch dimension
        x_lengths = LongTensor([tokens.size(0)])
        sid = LongTensor([sid_value])
        # length_scale is inverse speed: larger speed -> shorter output.
        output = model.infer(
            x, x_lengths, sid=sid,
            noise_scale=.667, noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )
        waveform = output[0][0, 0].data.cpu().float().numpy()
    return (hps.data.sampling_rate, waveform)
|
|
| |
# --- Model and configuration loading (runs once at import time) ---
model_path = "./pretrained_models/G_trilingual.pth"
config_path = "./configs/uma_trilingual.json"
onnx_dir = "./ONNX_net/G_trilingual/"
# Hyperparameters parsed from the JSON config file.
hps = utils.get_hparams_from_file(config_path)
# ONNX-backed VITS synthesizer; constructor arguments mirror training config.
model = ONNXVITS_infer.SynthesizerTrn(
    len(hps.symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    ONNX_dir=onnx_dir,
    **hps.model)
# optimizer arg is None: restore weights only, no training state.
utils.load_checkpoint(model_path, model, None)
model.eval()  # inference mode (disables dropout etc.)
# Speaker-name -> id mapping from config, plus the name list for the UI dropdown.
speaker_ids = hps.speakers
speakers = list(hps.speakers.keys())
# UI language choices; each must be a key of language_marks above.
languages = ['日本語', '简体中文', 'English', 'Mix']
|
|
| |
# Gradio UI: the input widgets map positionally onto tts_infer's parameters
# (text, speaker, language, speed, is_symbol).
iface = gr.Interface(
    fn=tts_infer,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Dropdown(choices=speakers, label="Speaker"),
        gr.Dropdown(choices=languages, label="Language"),
        # Minimum 0.1 keeps length_scale = 1/speed finite.
        gr.Slider(minimum=0.1, maximum=5, value=1, label="Speed"),
        gr.Checkbox(label="Symbol Input")
    ],
    outputs=gr.Audio(label="Generated Audio"),
    title="Simple Anime TTS API",
    allow_flagging="never"  # disable the user flagging button
)
|
|
if __name__ == "__main__":
    # --share exposes the app through a public gradio.live link.
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true", default=False)
    opts = cli.parse_args()
    iface.launch(share=opts.share)
|
|