biosn2 committed on
Commit
9c1ddda
·
verified ·
1 Parent(s): 22c181c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +183 -68
app.py CHANGED
@@ -1,110 +1,225 @@
 
1
  import os
2
- import shutil
3
  import threading
4
  import time
5
- import sys
6
- import torch
7
 
8
- from huggingface_hub import snapshot_download
 
 
9
 
10
  current_dir = os.path.dirname(os.path.abspath(__file__))
11
  sys.path.append(current_dir)
12
  sys.path.append(os.path.join(current_dir, "indextts"))
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  import gradio as gr
 
15
  from indextts.infer import IndexTTS
16
  from tools.i18n.i18n import I18nAuto
17
 
18
- # 设置多语言
19
- i18n = I18nAuto(language="en")
20
-
21
- # 下载模型
22
  MODE = 'local'
23
- snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
 
 
 
 
24
 
25
- # 自动选择设备:优先 GPU,没有就 CPU
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- print(f"🔥 Using device: {device}")
28
-
29
- # 初始化 TTS (不传 device)
30
- tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
31
-
32
- # 如果 IndexTTS 支持 to() 或 to_device(),切换设备
33
- if hasattr(tts, "to"):
34
- tts.to(device)
35
- elif hasattr(tts, "to_device"):
36
- tts.to_device(device)
37
- else:
38
- print("⚠️ IndexTTS 没有 to()/to_device() 方法,可能内部已自动处理设备。")
39
-
40
- # 确保必要的目录存在
41
- os.makedirs("outputs/tasks", exist_ok=True)
42
- os.makedirs("prompts", exist_ok=True)
43
-
44
- # 推理函数
45
- def infer(voice, text, output_path=None):
46
- if not tts:
47
- raise Exception("Model not loaded")
48
  if not output_path:
49
  output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
50
- tts.infer(voice, text, output_path)
51
- return output_path
52
-
53
- # API 调用
54
- def tts_api(voice, text):
55
- try:
56
- output_path = infer(voice, text)
57
- with open(output_path, "rb") as f:
58
- audio_bytes = f.read()
59
- return (200, {}, audio_bytes)
60
- except Exception as e:
61
- return (500, {"error": str(e)}, None)
62
-
63
- # 单次生成
64
- def gen_single(prompt, text):
65
- output_path = infer(prompt, text)
66
- return gr.update(value=output_path, visible=True)
67
-
68
- # 上传后启用按钮
 
 
 
 
 
 
 
 
 
69
  def update_prompt_audio():
70
  update_button = gr.update(interactive=True)
71
  return update_button
72
 
73
- # Gradio 界面
74
- with gr.Blocks() as demo:
75
  mutex = threading.Lock()
76
- gr.HTML(f'''
77
- <h2><center>IndexTTS WebUI</center></h2>
78
- <p align="center">当前设备: <b>{device}</b></p>
 
 
 
79
  ''')
80
  with gr.Tab("音频生成"):
81
  with gr.Row():
82
- os.makedirs("prompts", exist_ok=True)
83
- prompt_audio = gr.Audio(
84
- label="请上传参考音频",
85
- key="prompt_audio",
86
- sources=["upload", "microphone"],
87
- type="filepath"
88
- )
89
  prompt_list = os.listdir("prompts")
90
  default = ''
91
  if prompt_list:
92
  default = prompt_list[0]
93
- input_text_single = gr.Textbox(label="请输入目标文本", key="input_text_single")
94
- gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
95
- output_audio = gr.Audio(label="生成结果", visible=False, key="output_audio")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  prompt_audio.upload(update_prompt_audio,
98
  inputs=[],
99
  outputs=[gen_button])
100
 
101
  gen_button.click(gen_single,
102
- inputs=[prompt_audio, input_text_single],
 
 
 
103
  outputs=[output_audio])
104
 
 
 
 
 
 
 
 
 
 
105
  def main():
106
  tts.load_normalizer()
107
  demo.launch(server_name="0.0.0.0", server_port=7860)
108
 
109
  if __name__ == "__main__":
110
- main()
 
1
+ import json
2
  import os
3
+ import sys
4
  import threading
5
  import time
 
 
6
 
7
+ import warnings
8
+ warnings.filterwarnings("ignore", category=FutureWarning)
9
+ warnings.filterwarnings("ignore", category=UserWarning)
10
 
11
  current_dir = os.path.dirname(os.path.abspath(__file__))
12
  sys.path.append(current_dir)
13
  sys.path.append(os.path.join(current_dir, "indextts"))
14
 
15
+ import argparse
16
+ parser = argparse.ArgumentParser(description="IndexTTS WebUI")
17
+ parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
18
+ parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
19
+ parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
20
+ parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
21
+ cmd_args = parser.parse_args()
22
+ MODE = 'local'
23
+ snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
24
+
25
+ if not os.path.exists(cmd_args.model_dir):
26
+ print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
27
+ sys.exit(1)
28
+
29
+ for file in [
30
+ "bigvgan_generator.pth",
31
+ "bpe.model",
32
+ "gpt.pth",
33
+ "config.yaml",
34
+ ]:
35
+ file_path = os.path.join(cmd_args.model_dir, file)
36
+ if not os.path.exists(file_path):
37
+ print(f"Required file {file_path} does not exist. Please download it.")
38
+ sys.exit(1)
39
+
40
  import gradio as gr
41
+
42
  from indextts.infer import IndexTTS
43
  from tools.i18n.i18n import I18nAuto
44
 
45
+ i18n = I18nAuto(language="zh_CN")
 
 
 
46
  MODE = 'local'
47
+ tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),)
48
+
49
+
50
+ os.makedirs("outputs/tasks",exist_ok=True)
51
+ os.makedirs("prompts",exist_ok=True)
52
 
53
+ with open("tests/cases.jsonl", "r", encoding="utf-8") as f:
54
+ example_cases = []
55
+ for line in f:
56
+ line = line.strip()
57
+ if not line:
58
+ continue
59
+ example = json.loads(line)
60
+ example_cases.append([os.path.join("tests", example.get("prompt_audio", "sample_prompt.wav")),
61
+ example.get("text"), ["普通推理", "批次推理"][example.get("infer_mode", 0)]])
62
+
63
+ def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4,
64
+ *args, progress=gr.Progress()):
65
+ output_path = None
 
 
 
 
 
 
 
 
 
 
66
  if not output_path:
67
  output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
68
+ # set gradio progress
69
+ tts.gr_progress = progress
70
+ do_sample, top_p, top_k, temperature, \
71
+ length_penalty, num_beams, repetition_penalty, max_mel_tokens = args
72
+ kwargs = {
73
+ "do_sample": bool(do_sample),
74
+ "top_p": float(top_p),
75
+ "top_k": int(top_k) if int(top_k) > 0 else None,
76
+ "temperature": float(temperature),
77
+ "length_penalty": float(length_penalty),
78
+ "num_beams": num_beams,
79
+ "repetition_penalty": float(repetition_penalty),
80
+ "max_mel_tokens": int(max_mel_tokens),
81
+ # "typical_sampling": bool(typical_sampling),
82
+ # "typical_mass": float(typical_mass),
83
+ }
84
+ if infer_mode == "普通推理":
85
+ output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose,
86
+ max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
87
+ **kwargs)
88
+ else:
89
+ # 批次推理
90
+ output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose,
91
+ max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
92
+ sentences_bucket_max_size=(sentences_bucket_max_size),
93
+ **kwargs)
94
+ return gr.update(value=output,visible=True)
95
+
96
  def update_prompt_audio():
97
  update_button = gr.update(interactive=True)
98
  return update_button
99
 
100
+ with gr.Blocks(title="IndexTTS Demo") as demo:
 
101
  mutex = threading.Lock()
102
+ gr.HTML('''
103
+ <h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
104
+ <h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
105
+ <p align="center">
106
+ <a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
107
+ </p>
108
  ''')
109
  with gr.Tab("音频生成"):
110
  with gr.Row():
111
+ os.makedirs("prompts",exist_ok=True)
112
+ prompt_audio = gr.Audio(label="参考音频",key="prompt_audio",
113
+ sources=["upload","microphone"],type="filepath")
 
 
 
 
114
  prompt_list = os.listdir("prompts")
115
  default = ''
116
  if prompt_list:
117
  default = prompt_list[0]
118
+ with gr.Column():
119
+ input_text_single = gr.TextArea(label="文本",key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0"))
120
+ infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式",info="批次推理:更适合长句,性能翻倍",value="普通推理")
121
+ gen_button = gr.Button("生成语音", key="gen_button",interactive=True)
122
+ output_audio = gr.Audio(label="生成结果", visible=True,key="output_audio")
123
+ with gr.Accordion("高级生成参数设置", open=False):
124
+ with gr.Row():
125
+ with gr.Column(scale=1):
126
+ gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度详见[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
127
+ with gr.Row():
128
+ do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
129
+ temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
130
+ with gr.Row():
131
+ top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
132
+ top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
133
+ num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
134
+ with gr.Row():
135
+ repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
136
+ length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
137
+ max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens")
138
+ # with gr.Row():
139
+ # typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
140
+ # typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
141
+ with gr.Column(scale=2):
142
+ gr.Markdown("**分句设置** _参数会影响音频质量和生成速度_")
143
+ with gr.Row():
144
+ max_text_tokens_per_sentence = gr.Slider(
145
+ label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
146
+ info="建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高",
147
+ )
148
+ sentences_bucket_max_size = gr.Slider(
149
+ label="分句分桶的最大容量(批次推理生效)", value=4, minimum=1, maximum=16, step=1, key="sentences_bucket_max_size",
150
+ info="建议2-8之间,值越大,一批次推理包含的分句数越多,过大可能导致内存溢出",
151
+ )
152
+ with gr.Accordion("预览分句结果", open=True) as sentences_settings:
153
+ sentences_preview = gr.Dataframe(
154
+ headers=["序号", "分句内容", "Token数"],
155
+ key="sentences_preview",
156
+ wrap=True,
157
+ )
158
+ advanced_params = [
159
+ do_sample, top_p, top_k, temperature,
160
+ length_penalty, num_beams, repetition_penalty, max_mel_tokens,
161
+ # typical_sampling, typical_mass,
162
+ ]
163
+
164
+ if len(example_cases) > 0:
165
+ gr.Examples(
166
+ examples=example_cases,
167
+ inputs=[prompt_audio, input_text_single, infer_mode],
168
+ )
169
 
170
+ def on_input_text_change(text, max_tokens_per_sentence):
171
+ if text and len(text) > 0:
172
+ text_tokens_list = tts.tokenizer.tokenize(text)
173
+
174
+ sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence))
175
+ data = []
176
+ for i, s in enumerate(sentences):
177
+ sentence_str = ''.join(s)
178
+ tokens_count = len(s)
179
+ data.append([i, sentence_str, tokens_count])
180
+
181
+ return {
182
+ sentences_preview: gr.update(value=data, visible=True, type="array"),
183
+ }
184
+ else:
185
+ df = pd.DataFrame([], columns=["序号", "分句内容", "Token数"])
186
+ return {
187
+ sentences_preview: gr.update(value=df)
188
+ }
189
+
190
+ input_text_single.change(
191
+ on_input_text_change,
192
+ inputs=[input_text_single, max_text_tokens_per_sentence],
193
+ outputs=[sentences_preview]
194
+ )
195
+ max_text_tokens_per_sentence.change(
196
+ on_input_text_change,
197
+ inputs=[input_text_single, max_text_tokens_per_sentence],
198
+ outputs=[sentences_preview]
199
+ )
200
  prompt_audio.upload(update_prompt_audio,
201
  inputs=[],
202
  outputs=[gen_button])
203
 
204
  gen_button.click(gen_single,
205
+ inputs=[prompt_audio, input_text_single, infer_mode,
206
+ max_text_tokens_per_sentence, sentences_bucket_max_size,
207
+ *advanced_params,
208
+ ],
209
  outputs=[output_audio])
210
 
211
+
212
+
213
+
214
+ # if __name__ == "__main__":
215
+ # demo.queue(20)
216
+ # demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)
217
+
218
+
219
+
220
  def main():
221
  tts.load_normalizer()
222
  demo.launch(server_name="0.0.0.0", server_port=7860)
223
 
224
  if __name__ == "__main__":
225
+ main()