File size: 2,685 Bytes
ffccc5e
4b61f47
 
 
 
 
 
 
 
 
ffccc5e
 
 
 
 
 
 
 
 
29082ed
 
 
ffccc5e
a09d229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffccc5e
 
29082ed
ffccc5e
3d73fc7
 
29082ed
4b61f47
3d73fc7
4b61f47
 
 
 
29082ed
ffccc5e
3d73fc7
ffccc5e
 
 
 
 
 
3d73fc7
ffccc5e
 
4b61f47
ffccc5e
 
 
 
 
 
3d73fc7
ffccc5e
 
 
3d73fc7
 
ffccc5e
3d73fc7
 
 
 
ffccc5e
 
 
3d73fc7
 
ffccc5e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os

# Upper bound on native worker threads (the value this app originally
# hard-coded). It is additionally capped by the CPUs available to this
# process: running 16 busy intra-op/BLAS threads on e.g. a 2-vCPU container
# oversubscribes the cores and slows inference down.
_MAX_CPU_THREADS = 16


def _resolve_cpu_threads(default_cap: int = _MAX_CPU_THREADS) -> int:
    """Return the thread count to use for BLAS/OpenMP and onnxruntime.

    A value in the ``CPU_THREADS`` environment variable wins (clamped to at
    least 1; a non-integer value raises ``ValueError`` at startup so the typo
    is noticed). Otherwise the result is ``min(default_cap, available CPUs)``.
    """
    override = os.environ.get("CPU_THREADS", "").strip()
    if override:
        return max(1, int(override))
    try:
        # CPUs this process may actually run on (Linux); respects affinity.
        available = len(os.sched_getaffinity(0))
    except AttributeError:
        # sched_getaffinity is not available on every platform.
        available = os.cpu_count() or default_cap
    return max(1, min(default_cap, available))


CPU_THREADS = _resolve_cpu_threads()

# Set before numpy / onnxruntime are imported further down, since the native
# libraries generally read these when they initialise their thread pools.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "OMP_NUM_THREADS": str(CPU_THREADS),
        "MKL_NUM_THREADS": str(CPU_THREADS),
        "OPENBLAS_NUM_THREADS": str(CPU_THREADS),
        "NUMEXPR_NUM_THREADS": str(CPU_THREADS),
    }
)

import sys
import tempfile
import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download

# Hugging Face Hub repository holding the ONNX model files and (presumably)
# the `pocket_tts_onnx` module imported further down.
MODEL_REPO = "KevinAHM/pocket-tts-onnx"

# snapshot_download returns the local directory of the repository snapshot.
repo_dir = snapshot_download(repo_id=MODEL_REPO)
# Run from inside the snapshot -- presumably because pocket_tts_onnx locates
# its model files via relative paths (TODO confirm). Must happen before that
# module is imported/instantiated below.
os.chdir(repo_dir)
# Put the snapshot first on sys.path so `from pocket_tts_onnx import ...`
# below resolves to the module shipped inside the repository.
sys.path.insert(0, repo_dir)

import onnxruntime as ort

# Keep a handle on the real constructor; the wrapper below delegates to it.
_OriginalInferenceSession = ort.InferenceSession


def _PatchedInferenceSession(*args, **kwargs):
    """Create an ``ort.InferenceSession`` with this app's CPU tuning forced on.

    Installed as ``ort.InferenceSession`` (see below) before
    ``pocket_tts_onnx`` is imported, so every session it creates gets full
    graph optimisation, sequential execution and ``CPU_THREADS`` intra-op
    threads, overriding whatever those options were set to.

    ``sess_options`` is accepted either as a keyword or as the second
    positional argument (``InferenceSession(path, sess_options, ...)``), and an
    explicit ``None`` is treated as "not given". A caller-supplied
    ``SessionOptions`` object is modified in place.
    """
    positional = len(args) > 1
    so = args[1] if positional else kwargs.get("sess_options")
    if so is None:
        so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    so.intra_op_num_threads = CPU_THREADS
    # Inter-op threads are only used in parallel execution mode.
    so.inter_op_num_threads = 1
    if positional:
        # Keep the options in the same positional slot: also passing them as
        # a keyword would raise "got multiple values for argument".
        args = (args[0], so, *args[2:])
    else:
        kwargs["sess_options"] = so
    return _OriginalInferenceSession(*args, **kwargs)


ort.InferenceSession = _PatchedInferenceSession

from pocket_tts_onnx import PocketTTSOnnx

# Loaded PocketTTSOnnx models keyed by (temperature, lsd_steps), kept in
# least-recently-used order (dicts preserve insertion order, so the first key
# is always the least recently used entry).
tts_cache = {}

# Maximum number of models kept in tts_cache. Previously every slider
# combination used added a model for the life of the process; presumably each
# instance holds its own ONNX sessions, so an unbounded cache can exhaust RAM.
TTS_CACHE_MAX_ENTRIES = 4


def get_tts(temperature: float, lsd_steps: int):
    """Return a cached ``PocketTTSOnnx`` for these sampling settings.

    Args:
        temperature: Sampling temperature. Rounded to 4 decimal places so that
            float noise from the UI slider cannot create duplicate models.
        lsd_steps: Number of LSD steps; converted with ``int``.

    Returns:
        A ``PocketTTSOnnx`` instance (int8 precision, CPU device) for these
        settings. At most ``TTS_CACHE_MAX_ENTRIES`` instances (minimum 1) are
        kept; the least recently used one is dropped when a new one is built.
    """
    temperature = round(float(temperature), 4)
    lsd_steps = int(lsd_steps)
    key = (temperature, lsd_steps)

    tts = tts_cache.pop(key, None)
    if tts is None:
        # Evict first so an unused old model can be freed before the new one
        # is loaded.
        while tts_cache and len(tts_cache) >= TTS_CACHE_MAX_ENTRIES:
            del tts_cache[next(iter(tts_cache))]
        tts = PocketTTSOnnx(
            precision="int8",
            temperature=temperature,
            lsd_steps=lsd_steps,
            device="cpu",
        )
    # (Re-)insert at the end to mark this entry as most recently used.
    tts_cache[key] = tts
    return tts

def synthesize(ref_audio_path, text, temperature, lsd_steps):
    """Speak *text* in the voice of the recording at *ref_audio_path*.

    Args:
        ref_audio_path: Path to the reference recording, or ``None`` when the
            user supplied nothing (the UI input is ``gr.Audio(type="filepath")``).
        text: Text to synthesise; surrounding whitespace is stripped.
        temperature: Sampling temperature, forwarded to ``get_tts``.
        lsd_steps: LSD step count, forwarded to ``get_tts`` as an ``int``.

    Returns:
        Path of a newly created WAV file containing the generated audio.

    Raises:
        gr.Error: If no reference audio or no non-blank text was provided.
    """
    text = (text or "").strip()
    if not ref_audio_path:
        raise gr.Error("Upload a reference audio file.")
    if not text:
        raise gr.Error("Enter some text.")

    tts = get_tts(temperature, int(lsd_steps))
    audio = tts.generate(text=text, voice=ref_audio_path)

    # Fall back to 24 kHz if the model object does not expose its rate.
    sr = getattr(tts, "SAMPLE_RATE", 24000)
    audio_np = np.asarray(audio)
    if audio_np.ndim > 1:
        # Drop singleton axes (e.g. a leading batch/channel dimension).
        audio_np = audio_np.squeeze()

    # One file per request: a single fixed filename let concurrent requests
    # overwrite each other's output before it was served.
    # NOTE(review): these files are never deleted by this app; consider
    # periodic cleanup of the temp directory on long-running deployments.
    fd, out_path = tempfile.mkstemp(prefix="pocket_tts_", suffix=".wav")
    os.close(fd)
    try:
        sf.write(out_path, audio_np, sr)
    except Exception:
        os.remove(out_path)  # don't leave an empty or partial file behind
        raise
    return out_path

# Gradio UI: a reference voice and text go in, synthesised speech comes out.
# The click handler is also exposed on the HTTP API under the name "generate".
with gr.Blocks() as demo:
    with gr.Row():
        ref_audio = gr.Audio(type="filepath", label="Reference voice")
        text = gr.Textbox(lines=6, value="Hello, this is a test.", label="Text")
    with gr.Row():
        # Labels are required here: without them the two sliders are
        # indistinguishable in the rendered page.
        temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="Temperature")
        lsd_steps = gr.Slider(1, 20, value=10, step=1, label="LSD steps")
    generate = gr.Button("Generate")
    out_audio = gr.Audio(type="filepath", label="Generated speech")

    generate.click(
        fn=synthesize,
        inputs=[ref_audio, text, temperature, lsd_steps],
        outputs=[out_audio],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()