biosn2 commited on
Commit
1f6a011
·
verified ·
1 Parent(s): 9c1ddda

Upload app0.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app0.py +110 -0
app0.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import threading
4
+ import time
5
+ import sys
6
+ import torch
7
+
8
+ from huggingface_hub import snapshot_download
9
+
10
+ current_dir = os.path.dirname(os.path.abspath(__file__))
11
+ sys.path.append(current_dir)
12
+ sys.path.append(os.path.join(current_dir, "indextts"))
13
+
14
+ import gradio as gr
15
+ from indextts.infer import IndexTTS
16
+ from tools.i18n.i18n import I18nAuto
17
+
18
+ # 设置多语言
19
+ i18n = I18nAuto(language="en")
20
+
21
+ # 下载模型
22
+ MODE = 'local'
23
+ snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
24
+
25
+ # 自动选择设备:优先 GPU,没有就 CPU
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ print(f"🔥 Using device: {device}")
28
+
29
+ # 初始化 TTS (不传 device)
30
+ tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
31
+
32
+ # 如果 IndexTTS 支持 to() 或 to_device(),切换设备
33
+ if hasattr(tts, "to"):
34
+ tts.to(device)
35
+ elif hasattr(tts, "to_device"):
36
+ tts.to_device(device)
37
+ else:
38
+ print("⚠️ IndexTTS 没有 to()/to_device() 方法,可能内部已自动处理设备。")
39
+
40
+ # 确保必要的目录存在
41
+ os.makedirs("outputs/tasks", exist_ok=True)
42
+ os.makedirs("prompts", exist_ok=True)
43
+
44
+ # 推理函数
45
+ def infer(voice, text, output_path=None):
46
+ if not tts:
47
+ raise Exception("Model not loaded")
48
+ if not output_path:
49
+ output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
50
+ tts.infer(voice, text, output_path)
51
+ return output_path
52
+
53
+ # API 调用
54
+ def tts_api(voice, text):
55
+ try:
56
+ output_path = infer(voice, text)
57
+ with open(output_path, "rb") as f:
58
+ audio_bytes = f.read()
59
+ return (200, {}, audio_bytes)
60
+ except Exception as e:
61
+ return (500, {"error": str(e)}, None)
62
+
63
+ # 单次生成
64
+ def gen_single(prompt, text):
65
+ output_path = infer(prompt, text)
66
+ return gr.update(value=output_path, visible=True)
67
+
68
+ # 上传后启用按钮
69
+ def update_prompt_audio():
70
+ update_button = gr.update(interactive=True)
71
+ return update_button
72
+
73
+ # Gradio 界面
74
+ with gr.Blocks() as demo:
75
+ mutex = threading.Lock()
76
+ gr.HTML(f'''
77
+ <h2><center>IndexTTS WebUI</center></h2>
78
+ <p align="center">当前设备: <b>{device}</b></p>
79
+ ''')
80
+ with gr.Tab("音频生成"):
81
+ with gr.Row():
82
+ os.makedirs("prompts", exist_ok=True)
83
+ prompt_audio = gr.Audio(
84
+ label="请上传参考音频",
85
+ key="prompt_audio",
86
+ sources=["upload", "microphone"],
87
+ type="filepath"
88
+ )
89
+ prompt_list = os.listdir("prompts")
90
+ default = ''
91
+ if prompt_list:
92
+ default = prompt_list[0]
93
+ input_text_single = gr.Textbox(label="请输入目标文本", key="input_text_single")
94
+ gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
95
+ output_audio = gr.Audio(label="生成结果", visible=False, key="output_audio")
96
+
97
+ prompt_audio.upload(update_prompt_audio,
98
+ inputs=[],
99
+ outputs=[gen_button])
100
+
101
+ gen_button.click(gen_single,
102
+ inputs=[prompt_audio, input_text_single],
103
+ outputs=[output_audio])
104
+
105
+ def main():
106
+ tts.load_normalizer()
107
+ demo.launch(server_name="0.0.0.0", server_port=7860)
108
+
109
+ if __name__ == "__main__":
110
+ main()