biosn2 committed on
Commit
fda2eb7
·
verified ·
1 Parent(s): 05dad15

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +57 -132
app.py CHANGED
@@ -1,11 +1,8 @@
1
- import json
2
  import os
3
  import sys
4
- import threading
5
  import time
6
  import subprocess
7
-
8
- from huggingface_hub import snapshot_download
9
  import warnings
10
  warnings.filterwarnings("ignore", category=FutureWarning)
11
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -14,7 +11,7 @@ import argparse
14
  parser = argparse.ArgumentParser(description="IndexTTS WebUI")
15
  parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
16
  parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
17
- parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
18
  parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
19
  cmd_args = parser.parse_args()
20
 
@@ -22,164 +19,92 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
22
  sys.path.append(current_dir)
23
  sys.path.append(os.path.join(current_dir, "indextts"))
24
 
25
- MODE = 'local'
26
- snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
27
-
28
- if not os.path.exists(cmd_args.model_dir):
29
- print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
30
- sys.exit(1)
31
 
32
- for file in [
33
- "bigvgan_generator.pth",
34
- "bpe.model",
35
- "gpt.pth",
36
- "config.yaml",
37
- ]:
38
  file_path = os.path.join(cmd_args.model_dir, file)
39
  if not os.path.exists(file_path):
40
- print(f"Required file {file_path} does not exist. Please download it.")
41
- sys.exit(1)
42
 
 
43
  import gradio as gr
44
- import pandas as pd
45
-
46
  from indextts.infer import IndexTTS
47
- from tools.i18n.i18n import I18nAuto
48
-
49
- i18n = I18nAuto(language="zh_CN")
50
- tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),)
51
 
52
- os.makedirs("outputs/tasks", exist_ok=True)
53
- os.makedirs("prompts", exist_ok=True)
54
 
55
- # ----------------- 核心修改:保证 WAV 格式 & 打印进度 -----------------
56
  def ensure_wav(file_path):
57
  """将非 WAV 音频转换为 WAV"""
58
- if not file_path.lower().endswith(".wav"):
59
  wav_path = file_path.rsplit(".", 1)[0] + ".wav"
60
  subprocess.run(["ffmpeg", "-y", "-i", file_path, wav_path], check=True)
61
  return wav_path
62
  return file_path
63
 
64
  def progress_print(step, total, info=""):
65
- """生成音频进度打印到终端"""
66
  percent = int(step / total * 100)
67
  print(f"\r[{percent}%] {info}", end="", flush=True)
68
 
69
- def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4,
70
- *args, progress=gr.Progress()):
71
- prompt = ensure_wav(prompt)
72
- output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
73
- tts.gr_progress = progress
74
- tts.print_progress = progress_print # 将进度打印到终端
75
-
76
- do_sample, top_p, top_k, temperature, \
77
- length_penalty, num_beams, repetition_penalty, max_mel_tokens = args
78
 
79
  kwargs = {
80
  "do_sample": bool(do_sample),
81
  "top_p": float(top_p),
82
  "top_k": int(top_k) if int(top_k) > 0 else None,
83
  "temperature": float(temperature),
84
- "length_penalty": float(length_penalty),
85
- "num_beams": num_beams,
86
  "repetition_penalty": float(repetition_penalty),
 
87
  "max_mel_tokens": int(max_mel_tokens),
88
  }
89
 
90
- if infer_mode == "普通推理":
91
- output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose,
92
- max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
93
- **kwargs)
94
- else:
95
- output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose,
96
- max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
97
- sentences_bucket_max_size=int(sentences_bucket_max_size),
98
- **kwargs)
99
-
100
- print("\n生成完成:", output_path)
101
- return gr.update(value=output, visible=True)
102
-
103
- def update_prompt_audio():
104
- return gr.update(interactive=True)
105
-
106
- # ----------------- Gradio UI -----------------
107
  with gr.Blocks(title="IndexTTS Demo") as demo:
108
- mutex = threading.Lock()
109
- gr.HTML('''
110
- <h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
111
- <h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
112
- ''')
113
- with gr.Tab("音频生成"):
 
 
114
  with gr.Row():
115
- os.makedirs("prompts", exist_ok=True)
116
- prompt_audio = gr.Audio(label="参考音频", key="prompt_audio",
117
- sources=["upload","microphone"], type="filepath")
118
- with gr.Column():
119
- input_text_single = gr.TextArea(label="文本", key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0"))
120
- infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式", info="批次推理:更适合长句,性能翻倍", value="普通推理")
121
- gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
122
- output_audio = gr.Audio(label="生成结果", visible=True, key="output_audio")
123
- with gr.Accordion("高级生成参数设置", open=False):
124
- with gr.Row():
125
- with gr.Column(scale=1):
126
- gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度_")
127
- with gr.Row():
128
- do_sample = gr.Checkbox(label="do_sample", value=True)
129
- temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
130
- with gr.Row():
131
- top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
132
- top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
133
- num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
134
- with gr.Row():
135
- repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
136
- length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
137
- max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10)
138
- with gr.Column(scale=2):
139
- gr.Markdown("**分句设置**")
140
- with gr.Row():
141
- max_text_tokens_per_sentence = gr.Slider(label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2)
142
- sentences_bucket_max_size = gr.Slider(label="分句分桶的最大容量", value=4, minimum=1, maximum=16, step=1)
143
- with gr.Accordion("预览分句结果", open=True) as sentences_settings:
144
- sentences_preview = gr.Dataframe(headers=["序号", "分句内容", "Token数"], key="sentences_preview", wrap=True)
145
-
146
- advanced_params = [
147
- do_sample, top_p, top_k, temperature,
148
- length_penalty, num_beams, repetition_penalty, max_mel_tokens,
149
- ]
150
-
151
- input_text_single.change(
152
- lambda text, max_tokens_per_sentence: {
153
- sentences_preview: gr.update(value=[
154
- [i, ''.join(s), len(s)] for i, s in enumerate(
155
- tts.tokenizer.split_sentences(tts.tokenizer.tokenize(text), int(max_tokens_per_sentence))
156
- )
157
- ]) if text else gr.update(value=pd.DataFrame([], columns=["序号","分句内容","Token数"]))
158
- },
159
- inputs=[input_text_single, max_text_tokens_per_sentence],
160
- outputs=[sentences_preview]
161
- )
162
- max_text_tokens_per_sentence.change(
163
- lambda text, max_tokens_per_sentence: {
164
- sentences_preview: gr.update(value=[
165
- [i, ''.join(s), len(s)] for i, s in enumerate(
166
- tts.tokenizer.split_sentences(tts.tokenizer.tokenize(text), int(max_tokens_per_sentence))
167
- )
168
- ]) if text else gr.update(value=pd.DataFrame([], columns=["序号","分句内容","Token数"]))
169
- },
170
- inputs=[input_text_single, max_text_tokens_per_sentence],
171
- outputs=[sentences_preview]
172
  )
173
- prompt_audio.upload(update_prompt_audio, inputs=[], outputs=[gen_button])
174
-
175
- gen_button.click(gen_single,
176
- inputs=[prompt_audio, input_text_single, infer_mode,
177
- max_text_tokens_per_sentence, sentences_bucket_max_size,
178
- *advanced_params],
179
- outputs=[output_audio])
180
-
181
- def main():
182
- demo.launch(server_name="0.0.0.0", server_port=cmd_args.port)
183
 
 
184
  if __name__ == "__main__":
185
- main()
 
 
1
  import os
2
  import sys
 
3
  import time
4
  import subprocess
5
+ import threading
 
6
  import warnings
7
  warnings.filterwarnings("ignore", category=FutureWarning)
8
  warnings.filterwarnings("ignore", category=UserWarning)
 
11
  parser = argparse.ArgumentParser(description="IndexTTS WebUI")
12
  parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
13
  parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
14
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
15
  parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
16
  cmd_args = parser.parse_args()
17
 
 
19
  sys.path.append(current_dir)
20
  sys.path.append(os.path.join(current_dir, "indextts"))
21
 
22
+ # --- 检查模型 ---
23
+ from huggingface_hub import snapshot_download
24
+ snapshot_download("IndexTeam/IndexTTS-1.5", local_dir=cmd_args.model_dir)
 
 
 
25
 
26
+ for file in ["bigvgan_generator.pth","bpe.model","gpt.pth","config.yaml"]:
 
 
 
 
 
27
  file_path = os.path.join(cmd_args.model_dir, file)
28
  if not os.path.exists(file_path):
29
+ raise FileNotFoundError(f"{file_path} 不存在,请下载模型")
 
30
 
31
+ # --- 导入模块 ---
32
  import gradio as gr
 
 
33
  from indextts.infer import IndexTTS
 
 
 
 
34
 
35
+ tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"))
36
+ os.makedirs("outputs", exist_ok=True)
37
 
38
# --- Utility functions ---
def ensure_wav(file_path):
    """Return a WAV path for ``file_path``, converting with ffmpeg if needed.

    Args:
        file_path: Path to an audio file, or a falsy value (``None`` / ``""``)
            when no file was supplied.

    Returns:
        ``file_path`` unchanged when it is falsy or already ends in ``.wav``
        (case-insensitive); otherwise the path of the newly written WAV file,
        placed next to the input with the extension swapped for ``.wav``.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status.
        FileNotFoundError: the ``ffmpeg`` executable is not on ``PATH``.
    """
    if not file_path or file_path.lower().endswith(".wav"):
        return file_path

    # splitext only inspects the last path component, so a dot in a parent
    # directory ("/tmp/a.b/clip") or a dot-file (".clip") is not mistaken for
    # an extension separator the way rsplit(".", 1) would.
    wav_path = os.path.splitext(file_path)[0] + ".wav"
    # "-y" overwrites an existing target instead of prompting interactively.
    subprocess.run(["ffmpeg", "-y", "-i", file_path, wav_path], check=True)
    return wav_path


def progress_print(step, total, info=""):
    """Print a single-line ``[NN%] info`` progress indicator to the terminal.

    Args:
        step: Number of completed steps.
        total: Total number of steps; a zero/empty total is reported as 0%
            instead of raising ``ZeroDivisionError``.
        info: Free-form text shown after the percentage.
    """
    percent = int(step / total * 100) if total else 0
    # "\r" without a newline rewrites the same terminal line on every update.
    print(f"\r[{percent}%] {info}", end="", flush=True)
50
 
51
# --- Standard inference ---
def generate_audio(prompt_audio, text,
                   do_sample=True, top_p=0.8, top_k=30, temperature=1.0,
                   num_beams=3, repetition_penalty=10.0, length_penalty=0.0, max_mel_tokens=600):
    """Synthesize ``text`` in the voice of ``prompt_audio`` and return the WAV path.

    Args:
        prompt_audio: File path of the reference audio (required).
        text: Target text to synthesize (required, non-blank).
        do_sample, top_p, top_k, temperature, num_beams, repetition_penalty,
        length_penalty, max_mel_tokens: Generation settings, coerced to the
            expected Python types and forwarded to ``tts.infer`` as keyword
            arguments. A ``top_k`` of 0 or less is forwarded as ``None``.

    Returns:
        Path of the generated WAV file.

    Raises:
        gr.Error: An input is missing, or conversion/inference failed. The
            output component is a ``gr.Audio``, so a failure has to be raised
            rather than returned as a message string.
    """
    if not prompt_audio:
        raise gr.Error("请上传参考音频")
    if not text or not text.strip():
        raise gr.Error("请输入目标文本")

    # NOTE(review): the output name is fixed, so every request overwrites the
    # previous result and concurrent requests share one file -- confirm this
    # is intended before enabling queue concurrency.
    output_path = os.path.join("outputs", "out.wav")

    kwargs = {
        "do_sample": bool(do_sample),
        "top_p": float(top_p),
        "top_k": int(top_k) if int(top_k) > 0 else None,
        "temperature": float(temperature),
        "num_beams": int(num_beams),
        "repetition_penalty": float(repetition_penalty),
        "length_penalty": float(length_penalty),
        "max_mel_tokens": int(max_mel_tokens),
    }

    tts.print_progress = progress_print  # report progress on the terminal
    print(f"\n>> start inference for text: {text}")
    try:
        prompt_audio = ensure_wav(prompt_audio)
        tts.infer(prompt_audio, text, output_path, verbose=cmd_args.verbose, **kwargs)
    except Exception as e:
        # UI boundary: log the failure, then surface it in the browser.
        print(f"\n>> generation failed: {e}")
        raise gr.Error(f"生成失败: {e}") from e

    print(f"\n>> generated wav file: {output_path}")
    return output_path
78
+
79
# --- Gradio UI ---
# NOTE(review): nesting below was reconstructed from syntax; confirm whether
# the button and the output player belong inside the first row.
with gr.Blocks(title="IndexTTS Demo") as demo:
    gr.Markdown("## IndexTTS - 普通推理 (参考音频必填)")
    with gr.Row():
        # `sources` (a list) is the keyword this file used before this change;
        # the singular `source=` keyword belongs to the older Gradio 3 API.
        prompt_audio = gr.Audio(label="参考音频", sources=["upload"], type="filepath")
        text_input = gr.TextArea(label="文", placeholder="请输入目标文本")
    gen_button = gr.Button("生成语音")
    output_audio = gr.Audio(label="生成结果")

    with gr.Accordion("高级参数", open=False):
        with gr.Row():
            do_sample = gr.Checkbox(label="do_sample", value=True)
            temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
        with gr.Row():
            top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
            top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
            num_beams = gr.Slider(label="num_beams", minimum=1, maximum=10, value=3, step=1)
        with gr.Row():
            repetition_penalty = gr.Number(label="repetition_penalty", value=10.0, step=0.1)
            length_penalty = gr.Number(label="length_penalty", value=0.0, step=0.1)
            max_mel_tokens = gr.Slider(label="max_mel_tokens", minimum=50, maximum=600, value=600, step=10)

    # The input order must match the positional parameters of generate_audio.
    gen_button.click(
        generate_audio,
        inputs=[prompt_audio, text_input, do_sample, top_p, top_k, temperature,
                num_beams, repetition_penalty, length_penalty, max_mel_tokens],
        outputs=[output_audio]
    )
 
 
 
 
 
 
 
 
 
 
107
 
108
+ # --- 启动 ---
109
  if __name__ == "__main__":
110
+ demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)