# anime-tts / app.py
# Azul Alysum
# Remove unnecessary functions
# 99ed4b2
import argparse
import json
import os
import re
import tempfile
from pathlib import Path
import librosa
import numpy as np
import torch
from torch import no_grad, LongTensor
import commons
import utils
import gradio as gr
import gradio.utils as gr_utils
from gradio_client import utils as client_utils
import gradio.processing_utils as gr_processing_utils
from models import SynthesizerTrn
from text import text_to_sequence, _clean_text
from mel_processing import spectrogram_torch
# True only when SYSTEM=spaces (i.e. running on Hugging Face Spaces); tts_fn
# then enforces a maximum input text length.
limitation = os.environ.get("SYSTEM") == "spaces"
# Keep a handle on gradio's stock implementation so the patch can delegate to it.
audio_postprocess_ori = gr.Audio.postprocess


def audio_postprocess(self, y):
    """Postprocess audio output, then return it base64-encoded.

    Runs gradio's original ``Audio.postprocess`` and passes the ``"name"``
    entry of its result to ``encode_url_or_file_to_base64``.

    Returns:
        ``None`` if the original postprocess returned ``None``; otherwise the
        base64-encoded value produced by the encoder.
    """
    data = audio_postprocess_ori(self, y)
    if data is None:
        return None
    # Only the *lookup* of the helper is guarded: some gradio versions expose it
    # in gradio.processing_utils, others only in gradio_client.utils. Errors
    # raised while actually encoding are no longer swallowed and retried.
    try:
        encode = gr_processing_utils.encode_url_or_file_to_base64
    except AttributeError:
        encode = client_utils.encode_url_or_file_to_base64
    return encode(data["name"])


gr.Audio.postprocess = audio_postprocess
def get_text(text, hps, is_symbol):
    """Convert ``text`` to a LongTensor of symbol ids for the synthesizer.

    Args:
        text: Input text (or raw symbols when ``is_symbol`` is true).
        hps: Hyper-parameters; reads ``hps.symbols``, ``hps.data.text_cleaners``
            and ``hps.data.add_blank``.
        is_symbol: When true, no text cleaners are applied.

    Returns:
        A ``LongTensor`` of symbol ids. When ``hps.data.add_blank`` is set, the
        sequence is first passed through ``commons.intersperse`` with filler 0.
    """
    cleaners = [] if is_symbol else hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return LongTensor(sequence)
# Inline language tags such as "[JA]"; stripped before measuring text length.
# Raw string avoids the invalid "\[" escape warning of the original literal.
_LANG_TAG_RE = re.compile(r"\[([A-Z]{2})\]")


def tts_fn(text, speaker_id, speed, is_symbol):
    """Synthesize speech for ``text`` with the globally loaded VITS model.

    Reads the module-level ``hps``, ``model`` and ``device`` that are created
    in the ``__main__`` section, and ``limitation``.

    Args:
        text: Input text, possibly wrapped in language tags like ``[JA]...[JA]``.
        speaker_id: Speaker index passed to the model as ``sid``; must be a
            valid index into ``hps.speakers``.
        speed: Speaking-rate multiplier; the model gets ``length_scale=1/speed``.
        is_symbol: When true, ``text`` is raw symbols (no cleaners applied and
            a 3x larger length budget).

    Returns:
        ``(message, audio)``: on success ``("Success", (sampling_rate, ndarray))``,
        otherwise an ``"Error: ..."`` message and ``None``.
    """
    if limitation:
        text_len = len(_LANG_TAG_RE.sub("", text))
        max_len = 150
        if is_symbol:
            max_len *= 3
        if text_len > max_len:
            return "Error: Text is too long", None
    # An empty numeric field presumably arrives as None; report it instead of
    # letting LongTensor([None]) raise.
    if speaker_id is None:
        return "Error: Speaker ID is required", None
    speaker_id = int(speaker_id)
    # Speaker sids index hps.speakers (see the __main__ setup); reject anything
    # outside that range rather than failing inside the model.
    if not 0 <= speaker_id < len(hps.speakers):
        return f"Error: Invalid speaker ID {speaker_id}", None
    stn_tst = get_text(text, hps, is_symbol)
    with no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
        sid = LongTensor([speaker_id]).to(device)
        audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
                            length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
    # Drop tensor references promptly (they may live on the GPU).
    del stn_tst, x_tst, x_tst_lengths, sid
    return "Success", (hps.data.sampling_rate, audio)
def create_to_symbol_fn(hps):
    """Build the callback that toggles the text box between text and symbols.

    The returned ``to_symbol_fn(is_symbol_input, input_text, temp_text)``
    returns a ``(text_box_value, stashed_text)`` pair:

    - symbol mode on: the input cleaned with ``hps.data.text_cleaners``, and
      the raw input stashed for later;
    - symbol mode off: the stashed text for both values (restoring it).
    """
    def to_symbol_fn(is_symbol_input, input_text, temp_text):
        if not is_symbol_input:
            return temp_text, temp_text
        cleaned = _clean_text(input_text, hps.data.text_cleaners)
        return cleaned, input_text

    return to_symbol_fn
# Browser-side JavaScript for the "Download Audio" button. It is a str.format
# template: {audio_id} is substituted by the caller, and literal JS braces are
# therefore doubled ({{ }}). The script finds the <audio> element inside the
# element with that id (looking inside gradio-app's shadow root when present),
# returns early if there is none, and otherwise downloads its src by clicking a
# temporary <a> element with a random numeric ".wav" file name.
download_audio_js = """
() =>{{
let root = document.querySelector("body > gradio-app");
if (root.shadowRoot != null)
root = root.shadowRoot;
let audio = root.querySelector("#{audio_id}").querySelector("audio");
if (audio == undefined)
return;
audio = audio.src;
let oA = document.createElement("a");
oA.download = Math.floor(Math.random()*100000000)+'.wav';
oA.href = audio;
document.body.appendChild(oA);
oA.click();
oA.remove();
}}
"""
if __name__ == '__main__':
    # Entry point: load the single VITS model from saved_model/0 and serve a
    # Gradio UI for it. `device`, `hps` and `model` are module globals that
    # tts_fn reads.
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    args = parser.parse_args()
    device = torch.device(args.device)

    with open("saved_model/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)
    info = models_info['0']
    name = info["title"]
    author = info["author"]
    example = info["example"]
    config_path = "saved_model/0/config.json"
    model_path = "saved_model/0/model.pth"
    hps = utils.get_hparams_from_file(config_path)
    model = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    utils.load_checkpoint(model_path, model, None)
    model.eval().to(device)
    # Speakers literally named "None" are skipped (presumably placeholders).
    # speaker_ids keeps each remaining speaker's real model sid, aligned with
    # `speakers`, so the UI can show the id tts_fn actually expects.
    speaker_ids = [sid for sid, spk in enumerate(hps.speakers) if spk != "None"]
    speakers = [spk for spk in hps.speakers if spk != "None"]
    symbols = hps.symbols
    to_symbol_fn = create_to_symbol_fn(hps)
    # NOTE: the previously loaded (and never used) HuBERT soft-VC model from
    # torch.hub was removed; it cost a network download and memory at startup.

    app = gr.Blocks()
    with app:
        gr.Markdown("# Moe TTS And Voice Conversion Using VITS Model\n\n")
        with gr.Tabs():
            with gr.Tab("Model"):
                with gr.Column():
                    gr.Markdown(f"## {name}\n\n"
                                f"Model Author: {author}\n\n")
                    # tts_fn counts characters (not words) when enforcing the limit.
                    tts_input1 = gr.TextArea(label="Text (150 characters limitation)",
                                             value=f"[JA]{example}[JA]",
                                             elem_id="tts-input0")
                    tts_input2 = gr.Number(label="Speaker ID (check next tab)", value=0, precision=0)
                    tts_input3 = gr.Slider(label="Speed", value=1, minimum=0.5, maximum=2, step=0.1)
                    with gr.Accordion(label="Advanced Options", open=False):
                        # Holds the plain text while the box shows cleaned symbols.
                        temp_text_var = gr.Variable()
                        symbol_input = gr.Checkbox(value=False, label="Symbol input")
                        symbol_list = gr.Dataset(label="Symbol list", components=[tts_input1],
                                                 samples=[[x] for x in symbols],
                                                 elem_id="symbol-list0")
                        symbol_list_json = gr.Json(value=symbols, visible=False)
                    tts_submit = gr.Button("Generate", variant="primary")
                    tts_test = gr.Button("Test", variant="primary")
                    tts_output1 = gr.Textbox(label="Output Message")
                    tts_output2 = gr.Audio(label="Output Audio", elem_id="tts-audio0")
                    download = gr.Button("Download Audio")
                    download.click(None, [], [], _js=download_audio_js.format(audio_id="tts-audio0"))
                    tts_submit.click(tts_fn, [tts_input1, tts_input2, tts_input3, symbol_input],
                                     [tts_output1, tts_output2])
                    tts_test.click(tts_fn, [tts_input1, tts_input2, tts_input3, symbol_input],
                                   [tts_output1, tts_output2])
                    symbol_input.change(to_symbol_fn,
                                        [symbol_input, tts_input1, temp_text_var],
                                        [tts_input1, temp_text_var])
                    # Client-side: insert the clicked symbol at the cursor position
                    # of the text area, preserving scroll position.
                    symbol_list.click(None, [symbol_list, symbol_list_json], [],
                                      _js=f"""
(i,symbols) => {{
let root = document.querySelector("body > gradio-app");
if (root.shadowRoot != null)
root = root.shadowRoot;
let text_input = root.querySelector("#tts-input0").querySelector("textarea");
let startPos = text_input.selectionStart;
let endPos = text_input.selectionEnd;
let oldTxt = text_input.value;
let result = oldTxt.substring(0, startPos) + symbols[i] + oldTxt.substring(endPos);
text_input.value = result;
let x = window.scrollX, y = window.scrollY;
text_input.focus();
text_input.selectionStart = startPos + symbols[i].length;
text_input.selectionEnd = startPos + symbols[i].length;
text_input.blur();
window.scrollTo(x, y);
return [];
}}""")
            with gr.Tab("Voices"):
                gr.Markdown("## List of speakers and their IDs\n\n")
                with gr.Column():
                    # Show the model sid (what tts_fn passes to the model), not the
                    # position in the filtered list; they differ when a "None"
                    # speaker was skipped.
                    for sid, speaker in zip(speaker_ids, speakers):
                        gr.Markdown(f" {sid}: {speaker}\n")
        gr.Markdown(
            "Model official repo \n\n"
            "- [https://github.com/CjangCjengh/MoeGoe](https://github.com/CjangCjengh/MoeGoe)\n"
        )
    app.queue(concurrency_count=3).launch(show_api=True, share=args.share)