simler committed on
Commit
cbc36f5
·
verified ·
1 Parent(s): 422c754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -40
app.py CHANGED
@@ -1,13 +1,9 @@
1
  import os
2
  import sys
 
3
 
4
- # ==========================================
5
- # 1. 核心补丁:屏蔽 CUDA 和 FlashAttn (防崩)
6
- # ==========================================
7
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
8
- # 欺骗系统:Flash Attention 没安装
9
- sys.modules["flash_attn"] = None
10
-
11
  import torch
12
  torch.cuda.is_available = lambda: False
13
  torch.cuda.device_count = lambda: 0
@@ -15,24 +11,29 @@ def no_op(self, *args, **kwargs): return self
15
  torch.Tensor.cuda = no_op
16
  torch.nn.Module.cuda = no_op
17
 
18
- print("💉 环境手术完成: CUDA/FlashAttn 已屏蔽。")
19
 
20
- # ==========================================
21
- # 2. 导入原版逻辑
22
- # ==========================================
23
  sys.path.append(os.getcwd())
24
 
25
  try:
26
- # 你的目录里就在根目录下,直接导入!
27
  import inference_webui as core
28
  print("✅ 成功导入 inference_webui")
 
 
 
 
 
 
 
 
 
 
29
  except ImportError:
30
- print("❌ 找不到 inference_webui.py,请检查 Files")
31
  sys.exit(1)
32
 
33
- # ==========================================
34
- # 3. 自动加载模型
35
- # ==========================================
36
  def find_real_model(pattern, search_path="."):
37
  candidates = []
38
  for root, dirs, files in os.walk(search_path):
@@ -50,27 +51,31 @@ def find_real_model(pattern, search_path="."):
50
  gpt_path = find_real_model("s1v3.ckpt") or find_real_model("s1bert")
51
  sovits_path = find_real_model("s2Gv2ProPlus.pth") or find_real_model("s2G")
52
 
53
- if gpt_path and sovits_path:
54
- try:
55
- if hasattr(core, "change_gpt_weights"): core.change_gpt_weights(gpt_path=gpt_path)
56
- if hasattr(core, "change_sovits_weights"): core.change_sovits_weights(sovits_path=sovits_path)
57
- print(f"🎉 模型加载成功!\nGPT: {gpt_path}\nSoVITS: {sovits_path}")
58
- except Exception as e:
59
- print(f"⚠️ 模型加载报错: {e}")
60
- else:
61
- print("❌ 未找到模型文件")
 
 
 
 
 
 
62
 
63
- # ==========================================
64
- # 4. 修复语言参数的推理函数
65
- # ==========================================
66
  import soundfile as sf
67
  import gradio as gr
68
  import numpy as np
69
 
70
  REF_AUDIO = "ref.wav"
71
  REF_TEXT = "你好"
72
- # 🔴 关键修正:必须是汉字 "中文",不能是 "zh"
73
- LANG_SETTING = "中文"
74
 
75
  def run_predict(text):
76
  if not os.path.exists(REF_AUDIO):
@@ -78,20 +83,18 @@ def run_predict(text):
78
 
79
  print(f"📥 任务: {text}")
80
  try:
81
- # 自动识别可用函数
82
  inference_func = getattr(core, "get_tts_model", getattr(core, "get_tts_wav", None))
83
  if not inference_func:
84
  return None, "❌ 找不到推理函数"
85
 
86
  # 核心调用
87
- # 这里我们将 prompt_language 和 text_language 都设为 "中文"
88
- # 这就是之前 KeyError: 'zh' 的解法
89
  generator = inference_func(
90
  ref_wav_path=REF_AUDIO,
91
  prompt_text=REF_TEXT,
92
- prompt_language=LANG_SETTING, # <--- 改成了 "中文"
93
  text=text,
94
- text_language=LANG_SETTING, # <--- 改成了 "中文"
95
  how_to_cut="凑四句一切",
96
  top_k=5, top_p=1, temperature=1, ref_free=False
97
  )
@@ -109,14 +112,12 @@ def run_predict(text):
109
  traceback.print_exc()
110
  return None, f"💥 报错: {e}"
111
 
112
- # ==========================================
113
- # 5. 启动界面
114
- # ==========================================
115
  with gr.Blocks() as app:
116
- gr.Markdown(f"### GPT-SoVITS V2 (CPU Worker)")
117
 
118
  with gr.Row():
119
- inp = gr.Textbox(label="文本", value="这次肯定没问题了,开始批量生产吧。")
120
  btn = gr.Button("生成")
121
 
122
  with gr.Row():
 
1
  import os
2
  import sys
3
+ import logging
4
 
5
+ # --- 1. 基础环境设置 ---
6
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 
 
 
 
7
  import torch
8
  torch.cuda.is_available = lambda: False
9
  torch.cuda.device_count = lambda: 0
 
11
  torch.Tensor.cuda = no_op
12
  torch.nn.Module.cuda = no_op
13
 
14
+ print("💉 CUDA 补丁已注入")
15
 
16
+ # --- 2. 导入核心逻辑 ---
 
 
17
  sys.path.append(os.getcwd())
18
 
19
  try:
 
20
  import inference_webui as core
21
  print("✅ 成功导入 inference_webui")
22
+
23
+ # 🛑 关键修正:强制关闭半精度 (CPU 不支持 FP16,也能顺便避开 Flash Attention)
24
+ if hasattr(core, "is_half"):
25
+ core.is_half = False
26
+ print("✅ 强制禁用半精度 (is_half = False)")
27
+
28
+ if hasattr(core, "device"):
29
+ core.device = "cpu"
30
+ print("✅ 强制指定设备 (device = cpu)")
31
+
32
  except ImportError:
33
+ print("❌ 严重错误:找不到 inference_webui.py")
34
  sys.exit(1)
35
 
36
+ # --- 3. 自动寻找模型 ---
 
 
37
  def find_real_model(pattern, search_path="."):
38
  candidates = []
39
  for root, dirs, files in os.walk(search_path):
 
51
  gpt_path = find_real_model("s1v3.ckpt") or find_real_model("s1bert")
52
  sovits_path = find_real_model("s2Gv2ProPlus.pth") or find_real_model("s2G")
53
 
54
+ # --- 4. 加载模型 ---
55
+ try:
56
+ if gpt_path and sovits_path:
57
+ # 再次确保加载模型时不会启用半精度
58
+ core.is_half = False
59
+
60
+ if hasattr(core, "change_gpt_weights"):
61
+ core.change_gpt_weights(gpt_path=gpt_path)
62
+ if hasattr(core, "change_sovits_weights"):
63
+ core.change_sovits_weights(sovits_path=sovits_path)
64
+ print(f"🎉 模型加载成功!")
65
+ else:
66
+ print("❌ 未找到模型文件")
67
+ except Exception as e:
68
+ print(f"⚠️ 模型加载报错: {e}")
69
 
70
+ # --- 5. 推理逻辑 ---
 
 
71
  import soundfile as sf
72
  import gradio as gr
73
  import numpy as np
74
 
75
  REF_AUDIO = "ref.wav"
76
  REF_TEXT = "你好"
77
+ # 🛑 关键修正:语言必须是中文
78
+ REF_LANG = "中文"
79
 
80
  def run_predict(text):
81
  if not os.path.exists(REF_AUDIO):
 
83
 
84
  print(f"📥 任务: {text}")
85
  try:
86
+ # 自动识别函数
87
  inference_func = getattr(core, "get_tts_model", getattr(core, "get_tts_wav", None))
88
  if not inference_func:
89
  return None, "❌ 找不到推理函数"
90
 
91
  # 核心调用
 
 
92
  generator = inference_func(
93
  ref_wav_path=REF_AUDIO,
94
  prompt_text=REF_TEXT,
95
+ prompt_language=REF_LANG, # 中文
96
  text=text,
97
+ text_language="中文", # 中文
98
  how_to_cut="凑四句一切",
99
  top_k=5, top_p=1, temperature=1, ref_free=False
100
  )
 
112
  traceback.print_exc()
113
  return None, f"💥 报错: {e}"
114
 
115
+ # --- 6. 界面 ---
 
 
116
  with gr.Blocks() as app:
117
+ gr.Markdown(f"### GPT-SoVITS V2 (Final CPU)")
118
 
119
  with gr.Row():
120
+ inp = gr.Textbox(label="文本", value="这次一定行,不行我就吃键盘。")
121
  btn = gr.Button("生成")
122
 
123
  with gr.Row():