# jiangluohan / app.py — Hugging Face Space entry point.
# (Web-view residue preserved as a comment: "simler's picture / Update app.py / 6577c56 verified".)
import os
import sys
import requests
import shutil
# ==========================================
# 1. Base-environment sanitation: force torch onto the CPU
# ==========================================
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import torch

def _cuda_unavailable():
    # Replacement for torch.cuda.is_available — always report "no GPU".
    return False

def _cuda_device_count():
    # Replacement for torch.cuda.device_count — always report zero devices.
    return 0

torch.cuda.is_available = _cuda_unavailable
torch.cuda.device_count = _cuda_device_count

def no_op(self, *args, **kwargs):
    # Makes `.cuda()` a no-op that hands back the object unchanged,
    # so code that unconditionally calls it keeps working on CPU.
    return self

torch.Tensor.cuda = no_op
torch.nn.Module.cuda = no_op
print("💉 CUDA 已屏蔽")
# ==========================================
# 2. Saturation rescue: manifest of upstream files to re-download
# ==========================================
print("🚚 启动饱和式空投,正在重建 CPU 环境...")
BASE_URL = "https://raw.githubusercontent.com/RVC-Boss/GPT-SoVITS/main/GPT_SoVITS"

# (remote module, local destination) pairs. Most files land on their
# upstream path; the one exception is the t2s model, whose CPU version
# is written over the flash-attn (GPU) variant so the CPU code is used.
_PATCH_SPECS = [
    ("AR/models/t2s_model.py", "AR/models/t2s_model_flash_attn.py"),
    ("AR/models/utils.py", "AR/models/utils.py"),
    ("AR/modules/embedding.py", "AR/modules/embedding.py"),
    ("AR/modules/transformer.py", "AR/modules/transformer.py"),
    ("AR/modules/attention.py", "AR/modules/attention.py"),
    ("AR/modules/commons.py", "AR/modules/commons.py"),
]

# Expand the pairs into the {"url": ..., "path": ...} records the
# download loop below consumes.
FILES_TO_PATCH = [
    {"url": f"{BASE_URL}/{remote}", "path": local}
    for remote, local in _PATCH_SPECS
]
def _ensure_package_dirs(path):
    """Create every directory on *path* and drop an ``__init__.py`` into each,
    so the downloaded modules are importable as a package.

    BUG FIX: the original only wrote ``__init__.py`` when the leaf directory
    did NOT already exist, and never touched intermediate directories — a
    pre-existing ``AR/modules`` without an ``__init__.py`` stayed unimportable.
    """
    dir_name = os.path.dirname(path)
    if not dir_name:
        return
    os.makedirs(dir_name, exist_ok=True)
    # Walk the ancestor chain (e.g. "AR", then "AR/models") and make each
    # level a package; existing __init__.py files are left untouched.
    parts = dir_name.split("/")
    for depth in range(1, len(parts) + 1):
        init_path = os.path.join(*parts[:depth], "__init__.py")
        if not os.path.exists(init_path):
            with open(init_path, "w") as f:
                f.write("")

for item in FILES_TO_PATCH:
    try:
        # 1. Make sure the destination is an importable package directory.
        _ensure_package_dirs(item["path"])
        # 2. Fetch the upstream file and write it over the local copy.
        print(f"⬇️ 下载补丁: {os.path.basename(item['path'])} ...")
        resp = requests.get(item["url"], timeout=10)
        if resp.status_code == 200:
            with open(item["path"], "w", encoding="utf-8") as f:
                f.write(resp.text)
            print(f"✅ 修复成功: {item['path']}")
        else:
            print(f"❌ 下载失败 ({resp.status_code}): {item['url']}")
    except Exception as e:
        # Best-effort: a single failed patch should not abort the rest.
        print(f"⚠️ 文件处理错误: {e}")
# ==========================================
# 3. Import the core inference logic
# ==========================================
sys.path.append(os.getcwd())
try:
    import inference_webui as core
except Exception as e:
    # The whole app is useless without the core module — bail out.
    print(f"❌ 导入失败: {e}")
    sys.exit(1)
else:
    print("✅ 成功导入 inference_webui")
    # Force CPU-friendly globals when the module exposes them.
    if hasattr(core, "is_half"):
        core.is_half = False
    if hasattr(core, "device"):
        core.device = "cpu"
# ==========================================
# 4. Locate model weights automatically
# ==========================================
def find_model_file(pattern):
    """Return the first file under the CWD whose name contains *pattern*.

    Hugging Face cache artifacts (``.lock`` / ``.metadata``) are skipped,
    as is anything of 10 MB or less — tiny files are placeholders or
    pointers, not real weight checkpoints. Returns ``None`` on no match.
    """
    skip_suffixes = (".lock", ".metadata")
    min_bytes = 10 * 1024 * 1024
    for root, _dirs, names in os.walk("."):
        for name in names:
            if pattern not in name or name.endswith(skip_suffixes):
                continue
            candidate = os.path.join(root, name)
            if os.path.getsize(candidate) > min_bytes:
                return candidate
    return None
# Resolve the two checkpoints anywhere under the CWD: prefer the exact
# v3 / v2ProPlus filenames, then fall back to the generic name prefixes.
# Either may end up None; the loading step below handles that case.
gpt_path = find_model_file("s1v3.ckpt") or find_model_file("s1bert")
sovits_path = find_model_file("s2Gv2ProPlus.pth") or find_model_file("s2G")
# ==========================================
# 5. Load the model weights (CPU, fp32)
# ==========================================
try:
    if not (gpt_path and sovits_path):
        print("❌ 未找到模型文件")
    else:
        core.is_half = False  # half precision is a GPU feature; force fp32
        # Use the module's own weight-swap hooks when they are exposed.
        if hasattr(core, "change_gpt_weights"):
            core.change_gpt_weights(gpt_path=gpt_path)
        if hasattr(core, "change_sovits_weights"):
            core.change_sovits_weights(sovits_path=sovits_path)
        print("🎉 模型加载成功!(CPU Rebuilt)")
except Exception as e:
    # Loading failures are reported but not fatal, so the UI still starts.
    print(f"⚠️ 模型加载报错: {e}")
# ==========================================
# 6. Inference logic
# ==========================================
import soundfile as sf
import gradio as gr
import numpy as np

# Fixed reference clip used for voice cloning: a ref.wav placed next to
# this script, together with its transcript and language as the prompt.
REF_AUDIO = "ref.wav"
REF_TEXT = "你好"
REF_LANG = "中文"
def run_predict(text):
    """Synthesize *text* with the loaded GPT-SoVITS model.

    Returns an ``(audio_path, status_message)`` pair matching the two
    Gradio outputs; ``audio_path`` is ``None`` on any failure.
    """
    if not os.path.exists(REF_AUDIO):
        return None, "❌ 请上传 ref.wav"
    print(f"📥 任务: {text}")
    try:
        # Prefer get_tts_model when the core exposes it, otherwise fall
        # back to the classic get_tts_wav entry point.
        inference_func = getattr(core, "get_tts_model", getattr(core, "get_tts_wav", None))
        if not inference_func:
            return None, "❌ 找不到推理函数"
        generator = inference_func(
            ref_wav_path=REF_AUDIO,
            prompt_text=REF_TEXT,
            prompt_language=REF_LANG,
            text=text,
            text_language="中文",
            how_to_cut="凑四句一切",
            top_k=5, top_p=1, temperature=1, ref_free=False
        )
        result_list = list(generator)
        if not result_list:
            # BUG FIX: the original fell off the end of the try block here
            # and implicitly returned a bare None, breaking the two-output
            # (audio, log) contract expected by the Gradio click handler.
            return None, "❌ 推理未产生任何音频"
        sr, data = result_list[0]
        # Random suffix so concurrent/repeated requests don't clobber files.
        out_path = f"out_{os.urandom(4).hex()}.wav"
        sf.write(out_path, data, sr)
        print(f"✅ 生成完毕: {out_path}")
        return out_path, "✅ 成功"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"💥 报错: {e}"
# ==========================================
# 7. UI
# ==========================================
# NOTE: widget creation order inside the Blocks/Row context managers
# determines the on-screen layout, so these statements must stay in order.
with gr.Blocks() as app:
    gr.Markdown(f"### GPT-SoVITS V2 (Full Repair)")
    with gr.Row():
        inp = gr.Textbox(label="文本", value="所有零件都补齐了,这次一定行。")
        btn = gr.Button("生成")
    with gr.Row():
        out = gr.Audio(label="结果")
        log = gr.Textbox(label="日志")
    # run_predict returns (audio_path, status) → routed to (out, log);
    # api_name exposes the handler as the /predict API endpoint.
    btn.click(run_predict, [inp], [out, log], api_name="predict")
# Enable the request queue so long CPU inferences don't hit HTTP timeouts.
if __name__ == "__main__":
    app.queue().launch()