leesenx commited on
Commit
df4432b
·
verified ·
1 Parent(s): 0923232

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -81
app.py CHANGED
@@ -1,82 +1,64 @@
1
- import os
2
- import gradio as gr
3
- import numpy as np
4
- import soundfile as sf
5
- import tempfile
6
-
7
- from huggingface_hub import snapshot_download
8
- import onnxruntime as ort
9
-
10
-
11
- # =========================
12
- # 1. 自动下载模型(关键)
13
- # =========================
14
- MODEL_DIR = snapshot_download(
15
- repo_id="OpenMOSS-Team/MOSS-TTS-Nano-100M-ONNX",
16
- local_dir="./models",
17
- local_dir_use_symlinks=False
18
- )
19
-
20
-
21
- # =========================
22
- # 2. ONNX TTS 封装(简化可运行结构)
23
- # =========================
24
- class MOSSTTS:
25
- def __init__(self, model_dir):
26
- self.prefill = ort.InferenceSession(
27
- f"{model_dir}/moss_tts_prefill.onnx",
28
- providers=["CPUExecutionProvider"]
29
- )
30
-
31
- self.decode = ort.InferenceSession(
32
- f"{model_dir}/moss_tts_decode_step.onnx",
33
- providers=["CPUExecutionProvider"]
34
- )
35
-
36
- def infer(self, text):
37
- """
38
- ⚠️ 注意:这里是最小可跑demo结构
39
- 实际项目需要 tokenizer + codec
40
- """
41
-
42
- # fake token(占位)
43
- input_ids = np.array([[1, 2, 3]], dtype=np.int64)
44
-
45
- self.prefill.run(None, {"input_ids": input_ids})
46
-
47
- # fake audio
48
- wav = np.random.randn(16000 * 3).astype(np.float32)
49
- sr = 16000
50
 
51
- return wav, sr
52
-
53
-
54
- # =========================
55
- # 3. 初始化模型
56
- # =========================
57
- tts = MOSSTTS(MODEL_DIR)
58
-
59
-
60
- # =========================
61
- # 4. 推理函数
62
- # =========================
63
- def generate(text):
64
- wav, sr = tts.infer(text)
65
-
66
- out_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
67
- sf.write(out_file, wav, sr)
68
-
69
- return out_file
70
-
71
-
72
- # =========================
73
- # 5. Gradio UI
74
- # =========================
75
- demo = gr.Interface(
76
- fn=generate,
77
- inputs=gr.Textbox(label="Text"),
78
- outputs=gr.Audio(label="Output Audio"),
79
- title="MOSS-TTS-Nano ONNX (CPU)",
80
- )
81
-
82
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, time, io, wave
2
+ sys.path.insert(0, "/app")
3
+ os.environ["OMP_NUM_THREADS"] = "2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ import numpy as np
6
+ import gradio as gr
7
+ from onnx_tts_runtime import OnnxTtsRuntime, _merge_audio_channels, _concat_waveforms, _write_waveform_to_wav
8
+
9
+ MODEL_DIR = "/app/models"
10
+ BUILTIN_VOICES = None
11
+ runtime = None
12
+
13
+ def load_runtime():
14
+ global runtime, BUILTIN_VOICES
15
+ if runtime is not None:
16
+ return runtime
17
+ runtime = OnnxTtsRuntime(
18
+ model_dir=MODEL_DIR,
19
+ thread_count=2,
20
+ max_new_frames=375,
21
+ execution_provider="cpu",
22
+ )
23
+ BUILTIN_VOICES = [v["voice"] for v in runtime.list_builtin_voices()]
24
+ return runtime
25
+
26
+ def synthesize(text, voice, audio_path, sample_mode, max_frames):
27
+ rt = load_runtime()
28
+ t0 = time.time()
29
+ result = rt.synthesize(
30
+ text=text,
31
+ voice=voice if not audio_path else None,
32
+ prompt_audio_path=audio_path if audio_path else None,
33
+ sample_mode=sample_mode,
34
+ do_sample=(sample_mode != "greedy"),
35
+ streaming=True,
36
+ max_new_frames=int(max_frames),
37
+ enable_wetext=False,
38
+ enable_normalize_tts_text=False,
39
+ )
40
+ elapsed = time.time() - t0
41
+ sr = result["sample_rate"]
42
+ wav_path = result["audio_path"]
43
+ return wav_path, f"Done in {elapsed:.1f}s | {sr}Hz | {int(result['audio_token_ids'].shape[0])} frames"
44
+
45
+ with gr.Blocks(title="MOSS-TTS-Nano ONNX") as demo:
46
+ gr.Markdown("# MOSS-TTS-Nano-100M-ONNX\nCPU-only TTS with voice cloning. First run downloads ~730MB model.")
47
+ with gr.Row():
48
+ with gr.Column():
49
+ text_in = gr.Textbox(label="Text", value="Hello, welcome to MOSS TTS Nano.", lines=3)
50
+ with gr.Row():
51
+ voice_in = gr.Dropdown(choices=["Junhao","Zhiming","Weiguo","Xiaoyu","Yuewen","Lingyu","Trump","Ava","Bella","Adam","Nathan","Soyo","Saki","Mortis","Umiri","Mei","Anon","Arisa"], value="Junhao", label="Voice (overridden by ref audio)")
52
+ ref_audio = gr.Audio(label="Reference Audio (optional, for voice cloning)", type="filepath")
53
+ with gr.Row():
54
+ sample_mode = gr.Dropdown(choices=["fixed","greedy","full"], value="fixed", label="Sample Mode")
55
+ max_frames = gr.Slider(16, 750, value=375, step=1, label="Max Frames")
56
+ btn = gr.Button("Synthesize", variant="primary")
57
+ with gr.Column():
58
+ audio_out = gr.Audio(label="Generated Audio", type="filepath")
59
+ info_out = gr.Textbox(label="Info")
60
+ btn.click(fn=synthesize, inputs=[text_in, voice_in, ref_audio, sample_mode, max_frames], outputs=[audio_out, info_out])
61
+
62
+ if __name__ == "__main__":
63
+ load_runtime()
64
+ demo.launch()