# melotts_on_cpu / app.py
# Author: wolfofbackstreet — init (commit 0cc6d2d)
import functools
import inspect
import os
from pathlib import Path
from typing import Any, Callable, Iterable, List, Tuple, Union, get_type_hints

import gradio as gr
import jieba3
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
from huggingface_hub import snapshot_download
# Local directory that will hold the downloaded model repository.
model_path = "zh_en_melotts"
local_folder_path = Path(model_path)
# Create the directory if it doesn't exist
os.makedirs(local_folder_path, exist_ok=True)
# Download the model repository snapshot (model.onnx, lexicon.txt, tokens.txt)
# into the local folder at import time.
snapshot_download(
    repo_id="wolfofbackstreet/melotts_chinese_mix_english_onnx",
    local_dir=local_folder_path,
    # NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
    # huggingface_hub versions — confirm against the pinned hub version.
    local_dir_use_symlinks=False  # Recommended to avoid symlinks if you want portable files
)
def parse_docstring(func):
    """Pull "Title:" and "Description:" fields out of *func*'s docstring.

    Returns a dict with keys "title" (default "Untitled") and
    "description" (default "") for display in the Gradio page header.
    """
    doc = inspect.getdoc(func)
    if not doc:
        return {"title": "Untitled", "description": ""}
    lines = doc.splitlines()
    title = "Untitled"
    for line in lines:
        if line.startswith("Title:"):
            title = line.replace("Title:", "").strip()
            break
    desc_parts = [line.strip() for line in lines if line.startswith("Description:")]
    description = "\n".join(desc_parts).replace("Description:", "").strip()
    return {"title": title, "description": description}
def gradio_app_with_docs(func: Callable) -> Callable:
    """Decorator that builds a Gradio UI for *func* from its type hints.

    Input components are inferred from the parameter annotations, the
    output component from the return annotation, and the page title and
    description from the docstring (via ``parse_docstring``).  The
    returned wrapper calls *func* unchanged and gains a ``.launch()``
    method that starts the Gradio app.

    Raises:
        ValueError: if a parameter/return annotation has no component mapping.
    """
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    metadata = parse_docstring(func)

    def _map_type(t: type) -> "gr.Component":
        # Map a Python type hint to the Gradio component used for it.
        if t == str:
            return gr.Textbox(label="Input")
        elif t == int:
            return gr.Number(precision=0)
        elif t == float:
            return gr.Number()
        elif t == bool:
            return gr.Checkbox()
        elif hasattr(t, "__origin__") and t.__origin__ == list:
            elem_type = getattr(t, "__args__", (Any,))[0]
            if elem_type == str:
                return gr.Dropdown(choices=["Option1", "Option2"])
            else:
                raise ValueError(f"Unsupported list element type: {elem_type}")
        elif getattr(t, "__origin__", None) == tuple:
            args = getattr(t, "__args__", ())
            if len(args) == 2:
                first_type, second_type = args
                # (int, np.ndarray) is the Gradio convention for audio
                # output as a (sample_rate, waveform) pair.
                try:
                    if (
                        issubclass(first_type, int)
                        and hasattr(second_type, "__module__")
                        and second_type.__module__ == "numpy"
                    ):
                        return gr.Audio(label="Output", type="numpy")
                except TypeError:
                    # issubclass() rejects non-class hints (e.g. Any).
                    pass
        raise ValueError(f"Unsupported type: {t}")

    # Build one input component per parameter, labeled from its name.
    inputs = []
    for name, param in sig.parameters.items():
        if name == "self":
            continue
        component = _map_type(type_hints.get(name, Any))
        component.label = name.replace("_", " ").title()
        inputs.append(component)

    # Output component comes from the return annotation.
    outputs = _map_type(type_hints.get("return", Any))

    # Assemble the page: markdown header + auto-generated interface.
    with gr.Blocks() as demo:
        gr.Markdown(f"## {metadata['title']}\n{metadata['description']}")
        gr.Interface(fn=func, inputs=inputs, outputs=outputs)

    @functools.wraps(func)  # preserve __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    wrapper.launch = lambda: demo.launch()
    return wrapper
class Lexicon:
    """Maps words/phrases (and punctuation) to (phone ids, tone ids) pairs.

    Built from two text files: a tokens file mapping each phone symbol to an
    integer id, and a lexicon file listing ``word phone... tone...`` entries.
    """

    # Full-width punctuation normalized to its ASCII lexicon entry
    # before lookup.
    _PUNCT_MAP = {",": ",", "。": ".", "!": "!", "?": "?"}

    def __init__(self, lexion_filename: str, tokens_filename: str):
        # NOTE: the misspelled parameter name "lexion_filename" is kept
        # deliberately — callers pass it as a keyword argument.
        tokens = {}
        with open(tokens_filename, encoding="utf-8") as f:
            for line in f:
                s, i = line.split()
                tokens[s] = int(i)

        lexicon = {}
        with open(lexion_filename, encoding="utf-8") as f:
            for line in f:
                splits = line.split()
                word_or_phrase = splits[0]
                phone_tone_list = splits[1:]
                # Each entry is N phone symbols followed by N tone digits.
                assert len(phone_tone_list) % 2 == 0, len(phone_tone_list)
                half = len(phone_tone_list) // 2
                phones = [tokens[p] for p in phone_tone_list[:half]]
                tones = [int(t) for t in phone_tone_list[half:]]
                lexicon[word_or_phrase] = (phones, tones)

        # Aliases for characters that share a pronunciation; guard so a
        # lexicon missing the source entry doesn't crash construction.
        if "母" in lexicon:
            lexicon["呣"] = lexicon["母"]
        if "恩" in lexicon:
            lexicon["嗯"] = lexicon["恩"]
        self.lexicon = lexicon

        # Punctuation maps to its own single token with neutral tone 0.
        punctuation = ["!", "?", "…", ",", ".", "'", "-"]
        for p in punctuation:
            self.lexicon[p] = ([tokens[p]], [0])
        self.lexicon[" "] = ([tokens["_"]], [0])

    def _convert(self, text: str) -> Tuple[List[int], List[int]]:
        """Convert one word/phrase to (phones, tones).

        Full-width punctuation is normalized first.  Unknown multi-character
        strings are decomposed character by character; an unknown single
        character yields empty lists.
        """
        text = self._PUNCT_MAP.get(text, text)
        if text not in self.lexicon:
            phones: List[int] = []
            tones: List[int] = []
            if len(text) > 1:
                for ch in text:
                    p, t = self._convert(ch)
                    phones += p
                    tones += t
            return phones, tones
        return self.lexicon[text]

    def convert(self, text_list: Iterable[str]) -> Tuple[List[int], List[int]]:
        """Convert an iterable of words/phrases into flat phone/tone id lists."""
        phones: List[int] = []
        tones: List[int] = []
        for text in text_list:
            p, t = self._convert(text)
            phones += p
            tones += t
        return phones, tones
class OnnxModel:
    """CPU-only onnxruntime wrapper around a MeloTTS acoustic model."""

    def __init__(self, filename):
        """Load the ONNX model at *filename* and read its custom metadata.

        The converter stores bert_dim, ja_bert_dim, add_blank, sample_rate,
        speaker_id and lang_id as string key/value metadata on the model.
        """
        session_opts = ort.SessionOptions()
        # Single sequential graph: keep inter-op at 1, allow a few
        # intra-op threads for CPU kernels.
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4
        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = self.model.get_modelmeta().custom_metadata_map
        self.bert_dim = int(meta["bert_dim"])
        self.ja_bert_dim = int(meta["ja_bert_dim"])
        self.add_blank = int(meta["add_blank"])
        # (Original code assigned sample_rate twice; once is enough.)
        self.sample_rate = int(meta["sample_rate"])
        self.speaker_id = int(meta["speaker_id"])
        self.lang_id = int(meta["lang_id"])

    def __call__(self, x, tones, noise_scale=0.6, length_scale=1.0, noise_scale_w=0.8):
        """Run synthesis for one utterance.

        Args:
            x: 1-D int64 torch tensor of phone ids.
            tones: 1-D int64 torch tensor of tone ids (same length as x).
            noise_scale, length_scale, noise_scale_w: sampling/duration
                knobs; defaults match the values previously hard-coded.

        Returns:
            The model's "y" output indexed as [0][0] — presumably a 1-D
            float waveform at ``self.sample_rate`` (TODO confirm output
            layout against the exported model).
        """
        # Add the batch dimension expected by the graph.
        x = x.unsqueeze(0)
        tones = tones.unsqueeze(0)
        sid = torch.tensor([self.speaker_id], dtype=torch.int64)
        x_lengths = torch.tensor([x.shape[-1]], dtype=torch.int64)
        y = self.model.run(
            ["y"],
            {
                "x": x.numpy(),
                "x_lengths": x_lengths.numpy(),
                "tones": tones.numpy(),
                "sid": sid.numpy(),
                "noise_scale": np.array([noise_scale], dtype=np.float32),
                "noise_scale_w": np.array([noise_scale_w], dtype=np.float32),
                "length_scale": np.array([length_scale], dtype=np.float32),
            },
        )[0][0][0]
        return y
# Instantiate the ONNX TTS model and the pronunciation lexicon from the
# files fetched by snapshot_download() above.
model = OnnxModel(local_folder_path / "model.onnx")
lexicon = Lexicon(lexion_filename= local_folder_path / "lexicon.txt", tokens_filename= local_folder_path / "tokens.txt")
@gradio_app_with_docs
def tts(text: str) -> tuple[int, np.ndarray]:
    """
    Title: MeloTTS Onnx on CPU
    Description: A Simple app to test MeloTTS Chinese Mix English on CPU.
    Args:
        text (str): The text to synthesize (Chinese, English, or mixed).
    Returns:
        tuple[int, np.ndarray]: (sample_rate, waveform) for Gradio's Audio output.
    """
    # Lower-casing first is crucial so word segmentation and lexicon
    # lookup behave consistently for English words.
    text = text.lower()
    words = jieba3.jieba3(use_hmm=True).cut_text(text)
    phones, tones = lexicon.convert(words)

    if model.add_blank:
        # Interleave a blank symbol (id 0, tone 0) between every token
        # and at both ends, as the exported model expects.
        new_phones = [0] * (2 * len(phones) + 1)
        new_tones = [0] * (2 * len(tones) + 1)
        new_phones[1::2] = phones
        new_tones[1::2] = tones
        phones = new_phones
        tones = new_tones

    phones = torch.tensor(phones, dtype=torch.int64)
    tones = torch.tensor(tones, dtype=torch.int64)
    y = model(x=phones, tones=tones)
    return (model.sample_rate, y)
if __name__ == "__main__":
    # Launch the Gradio app (the decorator attached .launch() to tts).
    tts.launch()