simler commited on
Commit
87bde7f
·
verified ·
1 Parent(s): 9ee61c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -65
app.py CHANGED
@@ -2,49 +2,42 @@ import os
2
  import sys
3
 
4
  # ==========================================
5
- # 🛑 核心屏蔽补丁 (必须放在最最前面)
6
  # ==========================================
7
-
8
- # 1. 屏蔽 CUDA (显卡)
9
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
10
-
11
- # 2. 屏蔽 Flash Attention (关键!防崩核心)
12
- # 我们直接把这个模块设为 None,假装没安装
13
- # 这样 GPT-SoVITS 就会回退到普通 CPU 模式
14
- sys.modules["flash_attn"] = None
15
-
16
  import torch
17
-
18
- # 3. 彻底欺骗 Torch
19
  torch.cuda.is_available = lambda: False
20
  torch.cuda.device_count = lambda: 0
21
  def no_op(self, *args, **kwargs): return self
22
  torch.Tensor.cuda = no_op
23
  torch.nn.Module.cuda = no_op
24
-
25
- print("💉 环境手术完成: CUDA已移除, FlashAttn已禁用。")
26
 
27
  # ==========================================
28
- # 🚀 业务逻辑
29
  # ==========================================
30
- sys.path.append(os.getcwd())
 
 
 
 
31
 
32
- # 导入推理核心
33
  try:
34
- import inference_webui as core
35
- print("✅ 成功导入 inference_webui")
36
- except ImportError:
37
- print("❌ 找不到 inference_webui.py")
 
 
 
 
 
 
38
  sys.exit(1)
39
 
40
- # 自动寻找推理函数
41
- inference_func = None
42
- if hasattr(core, "get_tts_model"):
43
- inference_func = core.get_tts_model
44
- elif hasattr(core, "get_tts_wav"):
45
- inference_func = core.get_tts_wav
46
-
47
- # 自动寻找模型
48
  def find_real_model(pattern, search_path="."):
49
  candidates = []
50
  for root, dirs, files in os.walk(search_path):
@@ -52,84 +45,116 @@ def find_real_model(pattern, search_path="."):
52
  if pattern in file and not file.endswith(".lock") and not file.endswith(".metadata"):
53
  path = os.path.join(root, file)
54
  size_mb = os.path.getsize(path) / (1024 * 1024)
55
- if size_mb > 10:
56
  candidates.append((path, size_mb))
57
  if candidates:
58
  candidates.sort(key=lambda x: x[1], reverse=True)
59
- print(f"✅ 选中模型: {candidates[0][0]}")
60
  return candidates[0][0]
61
  return None
62
 
63
- gpt_path = find_real_model("s1v3.ckpt")
64
- if not gpt_path: gpt_path = find_real_model("s1bert")
 
 
 
 
65
 
66
- sovits_path = find_real_model("s2Gv2ProPlus.pth")
67
- if not sovits_path: sovits_path = find_real_model("s2G")
 
 
68
 
69
- # 加载模型
70
  try:
71
- if gpt_path and sovits_path:
72
- # 强制设置 config 为非半精度 (CPU不支持 half)
73
- # 这也是为了防止 Flash Attn 被错误触发
74
- if hasattr(core, "is_half"): core.is_half = False
 
 
 
 
 
 
75
 
76
- if hasattr(core, "change_gpt_weights"):
77
- core.change_gpt_weights(gpt_path=gpt_path)
78
- if hasattr(core, "change_sovits_weights"):
79
- core.change_sovits_weights(sovits_path=sovits_path)
80
- print("🎉 模型加载完成 (CPU模式)!")
 
 
 
 
 
81
  except Exception as e:
82
- print(f"⚠️ 模型加载报错: {e}")
83
 
84
- # 推理逻辑
 
 
85
  import soundfile as sf
86
  import gradio as gr
87
  import numpy as np
88
 
89
  REF_AUDIO = "ref.wav"
90
  REF_TEXT = "你好"
91
- REF_LANG = "中文" # 必须是中文
92
 
93
  def run_predict(text):
 
 
 
94
  if not os.path.exists(REF_AUDIO):
95
- return None, "❌ 错误:请上传 ref.wav"
96
 
97
  print(f"📥 任务: {text}")
98
  try:
99
- # 核心推理
100
- generator = inference_func(
101
- ref_wav_path=REF_AUDIO,
102
- prompt_text=REF_TEXT,
103
- prompt_language=REF_LANG,
104
- text=text,
105
- text_language="中文",
106
- how_to_cut="凑四句一切",
107
- top_k=5, top_p=1, temperature=1, ref_free=False
108
- )
 
 
 
 
 
 
 
109
 
 
110
  result_list = list(generator)
 
111
  if result_list:
112
  sr, data = result_list[0]
113
  out_path = f"out_{os.urandom(4).hex()}.wav"
114
  sf.write(out_path, data, sr)
115
- print(f"✅ 生成完毕: {out_path}")
116
- return out_path, "✅ 成功"
117
 
118
  except Exception as e:
119
  import traceback
120
  traceback.print_exc()
121
- return None, f"💥 报错: {e}"
122
 
123
- # 界面
 
 
124
  with gr.Blocks() as app:
125
- gr.Markdown(f"### GPT-SoVITS V2 (纯CPU版)")
 
 
126
 
127
  with gr.Row():
128
- inp = gr.Textbox(label="文本", value="终于成功了,这次一定能响。")
129
  btn = gr.Button("生成")
130
 
131
  with gr.Row():
132
- out = gr.Audio(label="音频")
133
  log = gr.Textbox(label="日志")
134
 
135
  btn.click(run_predict, [inp], [out, log], api_name="predict")
 
2
  import sys
3
 
4
  # ==========================================
5
+ # 1. 净化环境 (防止 GPU 报错)
6
  # ==========================================
 
 
7
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 
 
 
 
 
8
  import torch
 
 
9
  torch.cuda.is_available = lambda: False
10
  torch.cuda.device_count = lambda: 0
11
  def no_op(self, *args, **kwargs): return self
12
  torch.Tensor.cuda = no_op
13
  torch.nn.Module.cuda = no_op
14
+ print("💉 CUDA 已屏蔽,强制 CPU 模式")
 
15
 
16
  # ==========================================
17
+ # 2. 导入核心引擎 (不再依赖 webui)
18
  # ==========================================
19
+ cwd = os.getcwd()
20
+ sys.path.append(cwd)
21
+ sys.path.append(os.path.join(cwd, "GPT_SoVITS")) # 把子目录加入路径,防止找不到模块
22
+
23
+ print("📂 正在尝试导入核心引擎...")
24
 
 
25
  try:
26
+ # 尝试多种路径导入,总有一个是对的
27
+ try:
28
+ from TTS_infer_pack.TTS import TTS, TTS_Config
29
+ except ImportError:
30
+ from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
31
+ print("✅ 成功连接到底层 TTS 引擎!")
32
+ except ImportError as e:
33
+ print(f"❌ 核心引擎导入失败: {e}")
34
+ # 如果这里失败了,打印一下目录帮我找原因
35
+ print("目录结构:", os.listdir("."))
36
  sys.exit(1)
37
 
38
+ # ==========================================
39
+ # 3. 自动寻找模型 (智能模式)
40
+ # ==========================================
 
 
 
 
 
41
  def find_real_model(pattern, search_path="."):
42
  candidates = []
43
  for root, dirs, files in os.walk(search_path):
 
45
  if pattern in file and not file.endswith(".lock") and not file.endswith(".metadata"):
46
  path = os.path.join(root, file)
47
  size_mb = os.path.getsize(path) / (1024 * 1024)
48
+ if size_mb > 10: # 大于10MB才是真模型
49
  candidates.append((path, size_mb))
50
  if candidates:
51
  candidates.sort(key=lambda x: x[1], reverse=True)
 
52
  return candidates[0][0]
53
  return None
54
 
55
+ gpt_path = find_real_model("s1v3.ckpt") or find_real_model("s1bert")
56
+ sovits_path = find_real_model("s2Gv2ProPlus.pth") or find_real_model("s2G")
57
+
58
+ if not gpt_path or not sovits_path:
59
+ print("❌ 严重错误:没找到模型文件!请检查 Logs 下载进度。")
60
+ # 为了防止直接退出,这里不 sys.exit,让界面能显示出来报错
61
 
62
+ # ==========================================
63
+ # 4. 初始化引擎
64
+ # ==========================================
65
+ tts_pipeline = None
66
 
 
67
  try:
68
+ # 寻找配置文件
69
+ config_path = "GPT_SoVITS/configs/tts_infer.yaml"
70
+ if not os.path.exists(config_path):
71
+ config_path = "configs/tts_infer.yaml"
72
+
73
+ if os.path.exists(config_path):
74
+ print(f"⚙️ 加载配置: {config_path}")
75
+ tts_config = TTS_Config(config_path)
76
+ tts_config.device = "cpu"
77
+ tts_config.is_half = False
78
 
79
+ if gpt_path and sovits_path:
80
+ tts_config.t2s_weights_path = gpt_path
81
+ tts_config.vits_weights_path = sovits_path
82
+
83
+ # 启动引擎!
84
+ tts_pipeline = TTS(tts_config)
85
+ print("🚀 引擎启动成功!(Ready to Generate)")
86
+ else:
87
+ print("❌ 找不到 tts_infer.yaml 配置文件")
88
+
89
  except Exception as e:
90
+ print(f"⚠️ 引擎初始化异常: {e}")
91
 
92
+ # ==========================================
93
+ # 5. 定义接口
94
+ # ==========================================
95
  import soundfile as sf
96
  import gradio as gr
97
  import numpy as np
98
 
99
  REF_AUDIO = "ref.wav"
100
  REF_TEXT = "你好"
101
+ REF_LANG = "zh"
102
 
103
  def run_predict(text):
104
+ if tts_pipeline is None:
105
+ return None, "❌ 错误:引擎未启动 (模型或配置缺失)"
106
+
107
  if not os.path.exists(REF_AUDIO):
108
+ return None, "❌ 错误:根目录下没找到 ref.wav,请上传!"
109
 
110
  print(f"📥 任务: {text}")
111
  try:
112
+ # 手动构造请求参数
113
+ req = {
114
+ "text": text,
115
+ "text_lang": "zh",
116
+ "ref_audio_path": REF_AUDIO,
117
+ "prompt_text": REF_TEXT,
118
+ "prompt_lang": REF_LANG,
119
+ "top_k": 5, "top_p": 1, "temperature": 1,
120
+ "text_split_method": "cut4",
121
+ "batch_size": 1,
122
+ "speed_factor": 1.0,
123
+ "fragment_interval": 0.3,
124
+ "seed": -1,
125
+ "return_fragment": False,
126
+ "parallel_infer": True,
127
+ "repetition_penalty": 1.35
128
+ }
129
 
130
+ generator = tts_pipeline.run(req)
131
  result_list = list(generator)
132
+
133
  if result_list:
134
  sr, data = result_list[0]
135
  out_path = f"out_{os.urandom(4).hex()}.wav"
136
  sf.write(out_path, data, sr)
137
+ return out_path, "✅ 生成成功"
 
138
 
139
  except Exception as e:
140
  import traceback
141
  traceback.print_exc()
142
+ return None, f"💥 引擎报错: {e}"
143
 
144
+ # ==========================================
145
+ # 6. 启动界面
146
+ # ==========================================
147
  with gr.Blocks() as app:
148
+ gr.Markdown("### GPT-SoVITS V2 (Direct Core)")
149
+ gr.Markdown(f"GPT: `{os.path.basename(gpt_path) if gpt_path else '❌'}`")
150
+ gr.Markdown(f"SoVITS: `{os.path.basename(sovits_path) if sovits_path else '❌'}`")
151
 
152
  with gr.Row():
153
+ inp = gr.Textbox(label="文本", value="这下总该可以了吧!")
154
  btn = gr.Button("生成")
155
 
156
  with gr.Row():
157
+ out = gr.Audio(label="结果")
158
  log = gr.Textbox(label="日志")
159
 
160
  btn.click(run_predict, [inp], [out, log], api_name="predict")