# TTS Space — app.py
# Source: StaticFace's Space, "Update app.py", commit 3d73fc7 (verified)
import os
# Thread-count env vars must be exported BEFORE numpy / onnxruntime are
# imported below, otherwise the native math libraries ignore them.
CPU_THREADS = 16  # CPU threads for BLAS/OpenMP and, later, ONNX Runtime intra-op
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warning
os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
os.environ["OPENBLAS_NUM_THREADS"] = str(CPU_THREADS)
os.environ["NUMEXPR_NUM_THREADS"] = str(CPU_THREADS)
import sys
import tempfile
import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download
# Download the model repo snapshot and make its bundled Python package importable.
MODEL_REPO = "KevinAHM/pocket-tts-onnx"
repo_dir = snapshot_download(repo_id=MODEL_REPO)
# NOTE(review): chdir suggests the model code resolves asset paths relative
# to the CWD — TODO confirm against pocket_tts_onnx.
os.chdir(repo_dir)
sys.path.insert(0, repo_dir)
import onnxruntime as ort

_OriginalInferenceSession = ort.InferenceSession


def _PatchedInferenceSession(*args, **kwargs):
    """Wrap ``ort.InferenceSession`` so every session the model package
    creates runs with tuned SessionOptions: full graph optimization,
    sequential execution, and ``CPU_THREADS`` intra-op threads.

    Fixes over the naive patch:

    * an explicitly passed ``sess_options=None`` no longer crashes
      (``dict.get`` with a default does not catch an explicit ``None``);
    * a ``SessionOptions`` passed *positionally* (2nd positional arg) is
      tuned in place instead of colliding with a duplicate
      ``sess_options`` keyword argument.
    """
    if len(args) >= 2 and isinstance(args[1], ort.SessionOptions):
        # Caller supplied options positionally; tune that object in place.
        so = args[1]
    else:
        so = kwargs.get("sess_options") or ort.SessionOptions()
        kwargs["sess_options"] = so
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    so.intra_op_num_threads = CPU_THREADS
    # Sequential mode uses a single stream; parallelism comes from intra-op threads.
    so.inter_op_num_threads = 1
    return _OriginalInferenceSession(*args, **kwargs)


ort.InferenceSession = _PatchedInferenceSession

# Import AFTER the patch so PocketTTSOnnx builds its sessions through it.
from pocket_tts_onnx import PocketTTSOnnx
# Model cache keyed by (temperature, lsd_steps). Bounded: each entry holds a
# fully loaded ONNX model in RAM, so unbounded growth while users sweep the
# sliders would eventually exhaust memory.
tts_cache = {}
_TTS_CACHE_MAX = 4  # keep at most this many loaded models


def get_tts(temperature: float, lsd_steps: int):
    """Return a cached ``PocketTTSOnnx`` for the given sampling settings.

    Evicts the least-recently-used model once ``_TTS_CACHE_MAX`` distinct
    settings have been loaded. Dicts preserve insertion order, so after the
    move-to-end below the first key is always the LRU entry.

    Args:
        temperature: sampling temperature (normalized to float for the key).
        lsd_steps: diffusion/LSD step count (normalized to int for the key).
    """
    key = (float(temperature), int(lsd_steps))
    if key in tts_cache:
        tts_cache[key] = tts_cache.pop(key)  # move to end: most recently used
    else:
        while len(tts_cache) >= _TTS_CACHE_MAX:
            tts_cache.pop(next(iter(tts_cache)))  # drop least recently used
        tts_cache[key] = PocketTTSOnnx(
            precision="int8",
            temperature=float(temperature),
            lsd_steps=int(lsd_steps),
            device="cpu",
        )
    return tts_cache[key]
def synthesize(ref_audio_path, text, temperature, lsd_steps):
    """Generate speech for *text* in the voice of *ref_audio_path*.

    Args:
        ref_audio_path: path to the uploaded reference audio clip.
        text: text to synthesize; surrounding whitespace is stripped.
        temperature: sampling temperature forwarded to the model cache.
        lsd_steps: step count forwarded to the model cache.

    Returns:
        Path to a WAV file containing the generated audio.

    Raises:
        gr.Error: if the reference audio or the text is missing.
    """
    text = (text or "").strip()
    if not ref_audio_path:
        raise gr.Error("Upload a reference audio file.")
    if not text:
        raise gr.Error("Enter some text.")
    tts = get_tts(temperature, int(lsd_steps))
    audio = tts.generate(text=text, voice=ref_audio_path)
    # Model may not expose SAMPLE_RATE; 24 kHz is the assumed default — TODO confirm.
    sr = getattr(tts, "SAMPLE_RATE", 24000)
    audio_np = np.asarray(audio)
    if audio_np.ndim > 1:
        # Collapse a singleton batch/channel axis, e.g. (1, n) -> (n,).
        audio_np = audio_np.squeeze()
    # Unique file per request: the previous fixed filename in tempdir let
    # concurrent Gradio requests overwrite each other's output.
    fd, out_path = tempfile.mkstemp(prefix="pocket_tts_", suffix=".wav")
    os.close(fd)  # soundfile reopens by path; close the raw fd to avoid a leak
    sf.write(out_path, audio_np, sr)
    return out_path
with gr.Blocks() as demo:
    # Inputs: reference voice clip (file path) and the text to speak.
    with gr.Row():
        ref_audio = gr.Audio(type="filepath")
        text = gr.Textbox(lines=6, value="Hello, this is a test.")
    # Sampling controls; each distinct combination is a separate model-cache key.
    with gr.Row():
        temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05)
        lsd_steps = gr.Slider(1, 20, value=10, step=1)
    generate = gr.Button("Generate")
    # Output WAV is returned as a file path and played back by this component.
    out_audio = gr.Audio(type="filepath")
    # api_name exposes the handler as a named API endpoint for programmatic use.
    generate.click(
        fn=synthesize,
        inputs=[ref_audio, text, temperature, lsd_steps],
        outputs=[out_audio],
        api_name="generate",
    )
if __name__ == "__main__":
    demo.launch()