Yamade / app.py
Xenobd's picture
Update app.py
a7553c1 verified
# app_simple.py
import argparse
import re
import os
import librosa
import numpy as np
import torch
from torch import no_grad, LongTensor
import gradio as gr
import ONNXVITS_infer
import utils
import commons
from text import text_to_sequence
from mel_processing import spectrogram_torch
# Marker tag associated with each language name. tts_infer places the tag
# both before and after the input text; an empty tag leaves the text as-is.
language_marks = {
    "Japanese": "",
    "日本語": "[JA]",
    "简体中文": "[ZH]",
    "English": "[EN]",
    "Mix": "",
}
def get_text(text, hps, is_symbol):
    """Turn ``text`` into a ``LongTensor`` of ids via ``text_to_sequence``.

    Args:
        text: Input string handed to ``text_to_sequence``.
        hps: Hyper-parameter object; ``hps.symbols``,
            ``hps.data.text_cleaners`` and ``hps.data.add_blank`` are read.
        is_symbol: When truthy, an empty cleaner list is passed instead of
            ``hps.data.text_cleaners``.

    Returns:
        A ``LongTensor`` built from the resulting sequence.
    """
    if is_symbol:
        cleaners = []
    else:
        cleaners = hps.data.text_cleaners

    sequence = text_to_sequence(text, hps.symbols, cleaners)

    if hps.data.add_blank:
        # NOTE(review): presumably inserts id 0 between the symbols as a
        # blank token -- confirm against commons.intersperse.
        sequence = commons.intersperse(sequence, 0)

    return LongTensor(sequence)
def tts_infer(text, speaker, language, speed=1.0, is_symbol=False):
    """Synthesize ``text`` and return ``(sampling_rate, samples)``.

    Args:
        text: Text (or symbol string when ``is_symbol`` is truthy) to speak.
        speaker: Key looked up in the module-level ``speaker_ids`` mapping.
        language: Key of ``language_marks`` whose marker is placed before and
            after ``text``; ``None`` leaves the text untouched.
        speed: Positive speed factor; the model is called with
            ``length_scale = 1.0 / speed``.
        is_symbol: Forwarded unchanged to ``get_text``.

    Returns:
        Tuple ``(hps.data.sampling_rate, audio)`` where ``audio`` is the
        model output converted to a float NumPy array.

    Raises:
        ValueError: If ``speed`` is ``None`` or not greater than zero, or if
            ``speaker`` cannot be resolved through ``speaker_ids`` (for
            example when nothing was selected in the UI).
        KeyError: If ``language`` is not a key of ``language_marks``.
    """
    # Validate first: speed == 0 would divide by zero below and a negative
    # value would yield a negative length scale. The UI slider bounds do not
    # protect direct / API callers.
    if speed is None or not speed > 0:
        raise ValueError(f"speed must be a positive number, got {speed!r}")

    if language is not None:
        mark = language_marks[language]
        text = mark + text + mark

    # speaker_ids comes from the config and may be a dict or a config object,
    # so a failed lookup can surface as different exception types; normalise
    # them into one descriptive error.
    try:
        speaker_id = speaker_ids[speaker]
    except (KeyError, AttributeError, TypeError) as err:
        raise ValueError(f"unknown or missing speaker: {speaker!r}") from err

    stn_tst = get_text(text, hps, is_symbol)
    with no_grad():
        x_tst = stn_tst.unsqueeze(0)
        x_tst_lengths = LongTensor([stn_tst.size(0)])
        sid = LongTensor([speaker_id])
        output = model.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )
        audio = output[0][0, 0].data.cpu().float().numpy()
    return (hps.data.sampling_rate, audio)
# Model selection: edit these three paths to serve a different model.
config_path = "./configs/uma_trilingual.json"
model_path = "./pretrained_models/G_trilingual.pth"
onnx_dir = "./ONNX_net/G_trilingual/"

hps = utils.get_hparams_from_file(config_path)

# Speaker table taken from the config, plus the fixed list of language
# choices offered in the UI.
speaker_ids = hps.speakers
speakers = list(hps.speakers.keys())
languages = ['日本語', '简体中文', 'English', 'Mix']

# Build the synthesizer from the config values, then load the checkpoint
# weights into it and switch it to eval mode.
model = ONNXVITS_infer.SynthesizerTrn(
    len(hps.symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    ONNX_dir=onnx_dir,
    **hps.model
)
utils.load_checkpoint(model_path, model, None)
model.eval()
# Minimal Gradio form. The input components are listed in the same order as
# the parameters of tts_infer(text, speaker, language, speed, is_symbol).
iface = gr.Interface(
    title="Simple Anime TTS API",
    allow_flagging="never",
    fn=tts_infer,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Dropdown(choices=speakers, label="Speaker"),
        gr.Dropdown(choices=languages, label="Language"),
        gr.Slider(minimum=0.1, maximum=5, value=1, label="Speed"),
        gr.Checkbox(label="Symbol Input"),
    ],
    outputs=gr.Audio(label="Generated Audio"),
)
if __name__ == "__main__":
    # Script entry point: the optional --share flag is passed straight
    # through to iface.launch(share=...).
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", default=False, action="store_true")
    args = parser.parse_args()
    share_enabled = args.share
    iface.launch(share=share_enabled)