simler commited on
Commit
ad40ef8
·
verified ·
1 Parent(s): 5af22ee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ # --- 1. CPU 强制补丁 (关键!) ---
5
+ # 这几行代码必须放在所有 import 之前
6
+ # 它的作用是告诉程序:“别找显卡了,我就用 CPU”
7
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
8
+ import torch
9
+ torch.cuda.is_available = lambda : False
10
+ device = "cpu"
11
+ is_half = False # CPU 不支持半精度,必须强制 False
12
+
13
+ # --- 2. 路径适配 ---
14
+ now_dir = os.getcwd()
15
+ sys.path.append(now_dir)
16
+
17
+ # 引入核心库
18
+ import gradio as gr
19
+ import soundfile as sf
20
+ import numpy as np
21
+ from tools.i18n.i18n import I18nAuto
22
+ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_model
23
+
24
+ # 初始化
25
+ i18n = I18nAuto()
26
+
27
+ # --- 3. 模型配置 (请确保你上传了这些文件!) ---
28
+ # 这里你需要改成你上传的实际文件名
29
+ # 如果你还没上传,先用它自带的模型测试一下也行(如果有的话)
30
+ GPT_MODEL_PATH = "Jiang-GPT.ckpt" # <--- 你的 GPT 模型名
31
+ SOVITS_MODEL_PATH = "Jiang-SoVITS.pth" # <--- 你的 SoVITS 模型名
32
+ REF_AUDIO_PATH = "ref_jiang.wav" # <--- 你的参考音频名
33
+ REF_TEXT = "你好,我是江学姐。" # <--- 参考音频对应的字
34
+ REF_LANG = "zh" # <--- 参考音频语言
35
+
36
+ # --- 4. 加载模型 ---
37
+ print("🔄 正在初始化 CPU 推理环境...")
38
+
39
+ try:
40
+ # 尝试加载模型,如果文件不存在会报错,但 Space 不会崩,只会打印错误
41
+ if os.path.exists(GPT_MODEL_PATH) and os.path.exists(SOVITS_MODEL_PATH):
42
+ change_gpt_weights(gpt_path=GPT_MODEL_PATH)
43
+ change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
44
+ print("✅ 自定义模型加载成功!")
45
+ else:
46
+ print(f"⚠️ 未找到自定义模型 ({GPT_MODEL_PATH}),请在 Files 中上传!")
47
+ print("⚠️ 目前将尝试使用环境默认模型(如果存在)...")
48
+ except Exception as e:
49
+ print(f"❌ 模型加载出现小问题 (可忽略): {e}")
50
+
51
+ # --- 5. 推理函数 ---
52
+ def predict_worker(text, text_lang="zh"):
53
+ if not text: return None, "请输入文本"
54
+
55
+ print(f"📥 CPU 处理任务: {text[:10]}...")
56
+
57
+ # 检查参考音频是否存在
58
+ if not os.path.exists(REF_AUDIO_PATH):
59
+ return None, f"错误:找不到参考音频 {REF_AUDIO_PATH},请上传!"
60
+
61
+ try:
62
+ # 核心推理
63
+ generator = get_tts_model(
64
+ ref_wav_path=REF_AUDIO_PATH,
65
+ prompt_text=REF_TEXT,
66
+ prompt_language=REF_LANG,
67
+ text=text,
68
+ text_language=text_lang,
69
+ how_to_cut="凑四句一切",
70
+ top_k=5,
71
+ top_p=1.0,
72
+ temperature=1.0,
73
+ ref_free=False
74
+ )
75
+
76
+ result_list = list(generator)
77
+ if result_list:
78
+ sampling_rate, audio_data = result_list[0]
79
+ output_path = f"output_{os.urandom(4).hex()}.wav"
80
+ sf.write(output_path, audio_data, sampling_rate)
81
+ return output_path, "生成成功"
82
+
83
+ except Exception as e:
84
+ return None, f"报错了: {str(e)}\n(可能是内存不够或模型不匹配)"
85
+
86
+ # --- 6. 界面 ---
87
+ with gr.Blocks() as app:
88
+ gr.Markdown("# 🧩 小说 TTS 节点 (CPU 版)")
89
+
90
+ with gr.Row():
91
+ inp = gr.Textbox(label="文本")
92
+ out = gr.Audio(label="结果")
93
+ msg = gr.Textbox(label="日志")
94
+
95
+ btn = gr.Button("生成")
96
+ btn.click(predict_worker, [inp], [out, msg], api_name="predict")
97
+
98
+ if __name__ == "__main__":
99
+ app.queue().launch()