Spaces:
Sleeping
Sleeping
| import os | |
# Number of worker threads handed to the numeric runtimes. 16 is the upper
# bound, but never ask for more threads than the machine reports cores for:
# oversubscribing the CPU makes the math libraries slower, not faster.
CPU_THREADS = min(16, os.cpu_count() or 1)

# Turn off the tokenizers library's own parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Give every threaded math backend the same thread budget. This runs before
# the numeric packages are imported further down the file (presumably because
# those libraries read these variables when they are loaded).
for _thread_env_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_env_var] = str(CPU_THREADS)
| import sys | |
| import tempfile | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| from huggingface_hub import snapshot_download | |
# Hugging Face Hub repository that is downloaded at start-up.
MODEL_REPO = "KevinAHM/pocket-tts-onnx"

# Download (or reuse the cached copy of) the repository snapshot and remember
# where it lives on disk.
repo_dir = snapshot_download(MODEL_REPO)

# Put the snapshot first on the import path so modules shipped inside it can
# be imported, and make it the working directory (presumably so files in the
# snapshot can be opened by relative path -- TODO confirm).
# NOTE(review): the snapshot is not pinned to a revision, so whatever code the
# repository currently contains is what gets imported and executed.
sys.path[:0] = [repo_dir]
os.chdir(repo_dir)
| import onnxruntime as ort | |
# Keep a handle on the real constructor so the wrapper below can delegate to it.
_OriginalInferenceSession = ort.InferenceSession


def _PatchedInferenceSession(*args, **kwargs):
    """Create an ONNX Runtime session with this app's CPU threading settings.

    Accepts exactly what ``ort.InferenceSession`` accepts. The session options
    supplied by the caller (positionally as the second argument, or as the
    ``sess_options`` keyword) are reused when present; otherwise a fresh
    ``ort.SessionOptions`` is created. In both cases the options are then
    forced to: all graph optimizations, sequential execution,
    ``CPU_THREADS`` intra-op threads and a single inter-op thread.

    Returns:
        The object produced by the original ``ort.InferenceSession``.

    Note:
        A caller-supplied options object is modified in place.
    """
    # ``sess_options`` is the second positional parameter of InferenceSession,
    # so it may arrive in ``args`` rather than ``kwargs``. Injecting the
    # keyword in that case would raise "got multiple values for argument".
    passed_positionally = len(args) >= 2
    so = args[1] if passed_positionally else kwargs.get("sess_options")

    # Covers both "not supplied" and an explicit ``sess_options=None``.
    if so is None:
        so = ort.SessionOptions()

    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    so.intra_op_num_threads = CPU_THREADS
    so.inter_op_num_threads = 1

    # Hand the options back through the same channel they came in on.
    if passed_positionally:
        args = (args[0], so, *args[2:])
    else:
        kwargs["sess_options"] = so
    return _OriginalInferenceSession(*args, **kwargs)


# Install the wrapper so that code importing onnxruntime afterwards builds its
# sessions through it.
ort.InferenceSession = _PatchedInferenceSession
| from pocket_tts_onnx import PocketTTSOnnx | |
# Loaded engines keyed by (temperature, lsd_steps). A plain dict preserves
# insertion order, which is used here as recency order: the first key is the
# least recently used entry.
tts_cache = {}

# Upper bound on simultaneously cached engines. The UI sliders allow hundreds
# of (temperature, lsd_steps) combinations; keeping one engine per combination
# forever would grow memory without limit.
_TTS_CACHE_MAX = 4


def get_tts(temperature: float, lsd_steps: int):
    """Return a ``PocketTTSOnnx`` engine for the given sampling settings.

    Engines are cached per ``(temperature, lsd_steps)`` pair. At most
    ``_TTS_CACHE_MAX`` engines are kept; when the limit is reached the least
    recently used one is dropped before a new one is stored.

    Args:
        temperature: Sampling temperature; coerced to ``float``.
        lsd_steps: Step count forwarded to the engine; coerced to ``int``.

    Returns:
        An int8, CPU ``PocketTTSOnnx`` instance configured with the settings.
    """
    key = (float(temperature), int(lsd_steps))
    try:
        # On a hit, remove and re-insert below so the entry becomes the most
        # recently used one.
        tts = tts_cache.pop(key)
    except KeyError:
        tts = PocketTTSOnnx(
            precision="int8",
            temperature=key[0],
            lsd_steps=key[1],
            device="cpu",
        )
        # Make room: evict from the least-recently-used end.
        while len(tts_cache) >= _TTS_CACHE_MAX:
            tts_cache.pop(next(iter(tts_cache)), None)
    tts_cache[key] = tts
    return tts
def synthesize(ref_audio_path, text, temperature, lsd_steps):
    """Generate speech for ``text`` in the voice of a reference recording.

    Args:
        ref_audio_path: Path of the reference audio file.
        text: Text to speak; surrounding whitespace is ignored.
        temperature: Sampling temperature forwarded to ``get_tts``.
        lsd_steps: Step count forwarded to ``get_tts``.

    Returns:
        Path of a newly created WAV file containing the generated audio.
        The file is not removed by this function.

    Raises:
        gr.Error: If no reference audio or no text was provided.
    """
    text = (text or "").strip()
    if not ref_audio_path:
        raise gr.Error("Upload a reference audio file.")
    if not text:
        raise gr.Error("Enter some text.")

    tts = get_tts(temperature, int(lsd_steps))
    audio = tts.generate(text=text, voice=ref_audio_path)
    # Fall back to 24 kHz when the engine does not expose a sample rate.
    sr = getattr(tts, "SAMPLE_RATE", 24000)

    # Drop singleton dimensions so a (1, n)-style result is written as mono.
    audio_np = np.asarray(audio)
    if audio_np.ndim > 1:
        audio_np = audio_np.squeeze()

    # Write to a unique file per call. A single fixed file name would let
    # overlapping requests overwrite each other's output.
    fd, out_path = tempfile.mkstemp(prefix="pocket_tts_", suffix=".wav")
    os.close(fd)  # soundfile reopens the file by path
    try:
        sf.write(out_path, audio_np, sr)
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        os.unlink(out_path)
        raise
    return out_path
# Build the web UI: two input rows, a button, and the audio output.
# NOTE(review): the original indentation was lost; the button and the output
# player are assumed to sit directly under the Blocks container rather than
# inside the second row -- confirm against the deployed layout.
with gr.Blocks() as demo:
    # First row: reference voice recording and the text to speak.
    with gr.Row():
        ref_audio = gr.Audio(type="filepath")
        text = gr.Textbox(value="Hello, this is a test.", lines=6)

    # Second row: generation settings passed through to synthesize().
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=1.2, value=0.7, step=0.05)
        lsd_steps = gr.Slider(minimum=1, maximum=20, value=10, step=1)

    generate = gr.Button(value="Generate")
    out_audio = gr.Audio(type="filepath")

    # Run synthesize() on click; also exposed under the API name "generate".
    generate.click(
        synthesize,
        inputs=[ref_audio, text, temperature, lsd_steps],
        outputs=[out_audio],
        api_name="generate",
    )

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()