# EducaTTS / app.py
# Author: chnk58hoang
# Last commit: "change examples" (b3ab037)
import torch # isort:skip
torch.manual_seed(42)
import json
import re
import unicodedata
from types import SimpleNamespace
import numpy as np
import regex
import gradio as gr
from models import Vocoder, SynthesizerTrn
# Files the app expects to find in its working directory.
config_file = "config.json"
vocoder_model_path = "melgan.pth"
tts_model_path = "best_model_363794.pth"
phone_set_file = "vbx_phone_set.json"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameters are exposed through attribute access (hps.data.hop_length,
# hps.model, ...) by turning every JSON object into a SimpleNamespace.
with open(config_file, "rb") as f:
    hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))

# Load the phone set. The file is opened in binary mode, like the config
# above, so that json detects the encoding itself instead of relying on the
# platform default text encoding (the phone set holds non-ASCII symbols).
with open(phone_set_file, "rb") as f:
    phone_set = json.load(f)

# The rest of the file hard-codes index 0 as the separator phone and needs a
# "sil" phone, so fail loudly (even under `python -O`) when the file does not
# have that layout.
if phone_set[0][1:-1] != "SEP":
    raise ValueError(
        f"{phone_set_file}: first phone must be the SEP marker, got {phone_set[0]!r}"
    )
if "sil" not in phone_set:
    raise ValueError(f"{phone_set_file}: phone set has no 'sil' entry")
sil_idx = phone_set.index("sil")
# One or more whitespace characters (not referenced by the code visible in this file).
space_re = regex.compile(r"\s+")
# A run of ASCII digits (not referenced by the code visible in this file).
number_re = regex.compile("([0-9]+)")
# Vietnamese words for the digits 0-9, indexed by the digit's value.
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
# A number token: any mix of digits, "." and "," that ends in a digit.
num_re = regex.compile(r"([0-9.,]*[0-9])")
# Lowercase letters (with tone/diacritic variants) treated as text characters.
alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
# Matches any character that is NOT whitespace, a letter of `alphabet`, ".", "," or a digit
# (not referenced by the code visible in this file).
keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
# Matches any character that is NOT whitespace or a letter of `alphabet`
# (not referenced by the code visible in this file).
keep_text_re = regex.compile(rf"[^\s{alphabet}]")
def read_number(num: str) -> str:
    """Spell out a numeric string in Vietnamese words.

    Supported forms:

    * plain digit strings of 1 to 6 digits ("7", "15", "105", "12000");
    * "," as the decimal mark ("3,5" -> "ba phẩy năm"); every comma-separated
      part is read on its own and the parts are joined with "phẩy";
    * "." as the thousands separator with two or three groups
      ("1.000" -> "một ngàn", "2.500.000" -> "hai triệu năm trăm ngàn").

    Args:
        num: The number as written in the text.

    Returns:
        The spoken form, or ``num`` unchanged when the format is not
        supported (for example plain digit strings longer than 6 digits).
    """
    # isdecimal() (rather than isdigit()) guarantees that int() accepts the
    # string; isdigit() is also true for characters such as "²".
    if num.isdecimal():
        size = len(num)
        if size == 1:
            return digits[int(num)]

        if size == 2:
            tens, unit = divmod(int(num), 10)
            if tens == 0:
                # A leading zero is not a tens value ("05" is not fifteen):
                # read the two digits one by one.
                return digits[0] + " " + digits[unit]
            if unit == 0:
                return "mười" if tens == 1 else digits[tens] + " mươi"
            if tens == 1:
                return "mười " + ("lăm" if unit == 5 else digits[unit])
            # From 21 upwards a final 1 is read "mốt" and a final 5 "lăm".
            if unit == 1:
                end = "mốt"
            elif unit == 5:
                end = "lăm"
            else:
                end = digits[unit]
            return digits[tens] + " mươi " + end

        if size == 3:
            hundreds, rest = divmod(int(num), 100)
            spoken = digits[hundreds] + " trăm"
            if rest == 0:
                return spoken
            if rest < 10:
                # Zero in the tens position is read "lẻ".
                return spoken + " lẻ " + digits[rest]
            return spoken + " " + read_number(num[1:])

        if size <= 6:
            thousands, rest = divmod(int(num), 1000)
            spoken = read_number(str(thousands)) + " ngàn"
            if rest == 0:
                # Same reading as the "1.000" form: no trailing "không trăm".
                return spoken
            return spoken + " " + read_number(num[-3:])

        # Longer digit strings are not supported.
        return num

    if "," in num:
        # Any number of commas is accepted; unpacking exactly two parts used
        # to raise ValueError on inputs such as "1,2,3".
        return " phẩy ".join(read_number(part) for part in num.split(","))

    if "." in num:
        parts = num.split(".")
        if len(parts) == 2:
            head, tail = parts
            if tail == "000":
                return read_number(head) + " ngàn"
            # "00d" is read "lẻ d". The length check keeps tails such as
            # "00" or "0012" out of this branch, where they used to crash.
            if len(tail) == 3 and tail.startswith("00") and tail[2].isdecimal():
                return read_number(head) + " ngàn lẻ " + digits[int(tail[2])]
            return read_number(head) + " ngàn " + read_number(tail)
        if len(parts) == 3:
            millions = read_number(parts[0]) + " triệu"
            if parts[2] == "000":
                # Drop empty trailing groups: "1.000.000" is "một triệu".
                if parts[1] == "000":
                    return millions
                return millions + " " + read_number(parts[1]) + " ngàn"
            return (
                millions
                + " "
                + read_number(parts[1])
                + " ngàn "
                + read_number(parts[2])
            )

    return num
def text_to_phone_idx(text):
    """Convert raw text into a list of phone indices.

    The text is lowercased and NFKC-normalized, numbers are spelled out with
    ``read_number``, and every remaining character is mapped as follows:

    * one of ``: , . ! ? ; (``  -> the "sil" phone;
    * a character found in ``phone_set`` -> its index in ``phone_set``;
    * a space -> 0 (the separator phone);
    * anything else is dropped.

    The result always starts and ends with the "sil" phone.

    Args:
        text: The text to convert.

    Returns:
        The list of phone indices; ``[sil_idx]`` when no character of the
        text maps to a phone.
    """
    # lowercase, then unicode normalize
    text = unicodedata.normalize("NFKC", text.lower())

    # Spell out numbers BEFORE padding punctuation, otherwise the separators
    # inside "1.000" or "3,5" are split off and the number is read as
    # unrelated pieces. Only two shapes are treated as a single number:
    #   * 1-3 digits followed by one or two ".ddd" groups, optionally with a
    #     ",digits" decimal part (the group must not continue a longer
    #     dotted sequence such as a date or version string);
    #   * a digit run, optionally with a ",digits" decimal part.
    # Everything else falls back to reading each digit run on its own.
    number_pattern = (
        r"(?<![0-9]\.)[0-9]{1,3}(?:\.[0-9]{3}){1,2}(?![0-9]|\.[0-9])(?:,[0-9]+)?"
        r"|[0-9]+(?:,[0-9]+)?"
    )
    text = re.sub(
        number_pattern, lambda match: " " + read_number(match.group(0)) + " ", text
    )

    # Surround pause punctuation with spaces so it becomes its own token.
    for mark in ".,;:!?(":
        text = text.replace(mark, " " + mark + " ")

    # remove redundant, leading and trailing spaces
    text = " ".join(text.split())

    # Index lookup table; setdefault keeps the first occurrence, which is
    # what phone_set.index() returns.
    phone_to_idx = {}
    for idx, phone in enumerate(phone_set):
        phone_to_idx.setdefault(phone, idx)

    # convert characters to phone indices
    tokens = []
    for c in text:
        if c in ":,.!?;(":
            # punctuation becomes a <sil> phone
            tokens.append(sil_idx)
        elif c in phone_to_idx:
            tokens.append(phone_to_idx[c])
        elif c == " ":
            # add <sep> phone
            tokens.append(0)

    if not tokens:
        # Nothing pronounceable (empty text, unsupported symbols only):
        # return a lone silence instead of failing on tokens[0].
        return [sil_idx]

    if tokens[0] != sil_idx:
        # insert <sil> phone at the beginning
        tokens = [sil_idx, 0] + tokens
    if tokens[-1] != sil_idx:
        # and at the end
        tokens = tokens + [0, sil_idx]
    return tokens
def text_to_speech(vocoder, generator, text):
    """Synthesize one piece of text into 16-bit PCM samples.

    Args:
        vocoder: Model called as ``vocoder(phone_idx, phone_length)``; channel
            0 of its output is used as the per-phone duration.
        generator: Model whose ``infer`` method produces the waveform from the
            phones and the phone-to-frame alignment.
        text: Text to read; only the first 500 characters are used.

    Returns:
        A 1-D ``numpy.int16`` array of audio samples.
    """
    # prevent too long text
    text = text[:500]

    tokens = text_to_phone_idx(text)
    # Batch of one: phone_idx has shape (1, num_phones).
    phone_length = torch.tensor([len(tokens)], dtype=torch.long, device=device)
    phone_idx = torch.tensor([tokens], dtype=torch.long, device=device)

    # predict phoneme duration
    # NOTE(review): the * 1000 here and the / 1000 below suggest the model
    # outputs seconds and the intermediate unit is milliseconds — confirm.
    with torch.inference_mode():
        phone_duration = vocoder(phone_idx, phone_length)[:, :, 0] * 1000
    # Silence phones last at least 200 (same unit as above) ...
    phone_duration = torch.where(
        phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
    )
    # ... and separator phones (index 0) take no time at all.
    phone_duration = torch.where(phone_idx == 0, 0, phone_duration)

    # Turn durations into [start, end) spans measured in spectrogram frames.
    end_time = torch.cumsum(phone_duration, dim=-1)
    start_time = end_time - phone_duration
    start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
    end_frame = end_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
    spec_length = end_frame.max(dim=-1).values

    # attn[b, t, p] is 1.0 when frame t falls inside the span of phone p.
    pos = torch.arange(0, spec_length.item(), device=device)
    attn = torch.logical_and(
        pos[None, :, None] >= start_frame[:, None, :],
        pos[None, :, None] < end_frame[:, None, :],
    ).float()

    # generate waveform
    with torch.inference_mode():
        y_hat = generator.infer(
            phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.667
        )[0]
    wave = y_hat[0, 0].data.cpu().numpy()

    # Scale to the int16 range and clip first: a sample at (or above) +1.0
    # would otherwise scale to 32768, which does not fit in int16.
    return np.clip(wave * (2**15), -(2**15), 2**15 - 1).astype(np.int16)
def load_models():
    """Build both networks, load their checkpoints and switch them to eval mode.

    Returns:
        A ``(vocoder, generator)`` tuple, both moved to ``device``.
    """
    # First network: weights come straight from `vocoder_model_path`.
    vocoder_net = Vocoder(hps.data.vocab_size, 64, 4).to(device)
    vocoder_state = torch.load(vocoder_model_path, map_location=device)
    vocoder_net.load_state_dict(vocoder_state)
    vocoder_net = vocoder_net.eval()

    # Second network: sized from the hyper-parameters in `hps`.
    spec_channels = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length
    synthesizer = SynthesizerTrn(
        hps.data.vocab_size,
        spec_channels,
        segment_frames,
        **vars(hps.model),
    ).to(device)
    # The enc_q sub-module is removed before loading; strict=False below
    # tolerates the checkpoint entries that no longer have a destination.
    del synthesizer.enc_q

    checkpoint = torch.load(tts_model_path, map_location=device)
    # Strip a leading "module." from parameter names.
    prefix = "module."
    weights = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint["net_g"].items()
    }
    synthesizer.load_state_dict(weights, strict=False)
    del checkpoint, weights

    synthesizer = synthesizer.eval()
    return vocoder_net, synthesizer
# Models are loaded on first use and then shared by every request.
_LOADED_MODELS = None


def _get_models():
    """Return the (vocoder, generator) pair, loading it only on the first call."""
    global _LOADED_MODELS
    if _LOADED_MODELS is None:
        _LOADED_MODELS = load_models()
    return _LOADED_MODELS


def speak(text):
    """Synthesize a multi-line text, one clip per non-empty line.

    Args:
        text: Text to read; lines are separated by "\\n" and blank lines
            are skipped.

    Returns:
        A ``(sampling_rate, samples)`` tuple where ``samples`` is the
        concatenation of the clips of all lines.

    Raises:
        ValueError: If the text contains no non-blank line.
    """
    # Reloading both checkpoints from disk on every request is wasted work;
    # reuse the instances loaded by the first request.
    vocoder, generator = _get_models()

    paragraphs = (line.strip() for line in text.split("\n"))
    clips = [
        text_to_speech(vocoder, generator, paragraph)
        for paragraph in paragraphs
        if paragraph
    ]
    if not clips:
        # np.concatenate([]) raises the same exception type, but with a
        # message that says nothing about the actual problem.
        raise ValueError("speak() needs at least one non-blank line of text")

    return hps.data.sampling_rate, np.concatenate(clips)
# Page title shown above the demo.
title = "Educa Text to Speech"

# Single text box in, audio player out; `speak` does the synthesis.
demo = gr.Interface(
    fn=speak,
    inputs="text",
    outputs="audio",
    title=title,
    examples=[
        "Xin chào, đây là một thử nghiệm tổng hợp giọng nói tới từ công ty cổ phần giáo dục ê đu ca.",
    ],
    theme="default",
    allow_screenshot=False,
    allow_flagging="never",
)

# Start the web server.
demo.launch(debug=False)