Spaces:
Runtime error
Runtime error
File size: 10,481 Bytes
79f89ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 |
import builtins
import os
import sys
import shutil
import io
import time
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
def _auto_confirm(prompt=""):
    """Stand-in for the builtin input(): answer every prompt with "y"."""
    return "y"


# Must be installed before genie_tts is imported (at the end of this block) so
# that any interactive prompt it may issue is auto-confirmed instead of
# blocking the server.
builtins.input = _auto_confirm

# Prefer the genie_tts sources living next to this file over any installed
# copy by putting this directory at the front of the module search path.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Data directory for the Space deployment; when running locally make sure this
# directory exists.
os.environ["GENIE_DATA_DIR"] = "/app/GenieData"

# Automatic asset download is disabled: GenieData is assumed to be
# pre-installed in the image. The previous download logic, kept for reference:
# if not os.path.exists("/app/GenieData/G2P"):
#     print("📦 Downloading GenieData Assets...")
#     from huggingface_hub import snapshot_download
#     snapshot_download(repo_id="High-Logic/Genie", allow_patterns=["GenieData/*"], local_dir="/app", local_dir_use_symlinks=False)
import genie_tts
app = FastAPI()

# Root directory holding one sub-directory per character model.
MODELS_ROOT = "/app/models"
os.makedirs(MODELS_ROOT, exist_ok=True)

# Default reference audio per character:
#   name -> {"path": wav file, "text": its transcript, "lang": its language}.
# Read by /tts and /tts-stream; /load_model adds entries at runtime.
REF_CACHE = {}


def _load_builtin_character(name, subdir, lang, ref_text):
    """Load a model bundled under MODELS_ROOT and register its default reference.

    The model directory is MODELS_ROOT/<subdir>; its reference clip is expected
    at <model dir>/ref.wav with transcript ``ref_text``. Deriving both paths
    from MODELS_ROOT keeps them in sync if the root ever moves.
    """
    model_dir = os.path.join(MODELS_ROOT, subdir)
    genie_tts.load_character(name, model_dir, lang)
    REF_CACHE[name] = {
        "path": os.path.join(model_dir, "ref.wav"),
        "text": ref_text,
        "lang": lang,
    }


# Characters loaded at startup (models/base and models/god).
_load_builtin_character(
    "Base", "base", "zh",
    "琴是个称职的好团长。看到她认真工作的样子,就连我也忍不住想要多帮她一把。",
)
_load_builtin_character(
    "god", "god", "zh",
    "很多人的一生,写于纸上也不过几行,大多都是些无聊的故事啊。",
)
@app.post("/load_model")
async def load_model(character_name: str = Form(...), model_path: str = Form(...), language: str = Form("zh")):
    """Dynamically load a character model and detect its default reference audio.

    Args:
        character_name: Name under which the model is registered with genie_tts.
        model_path: Model directory relative to /app, e.g. "models/my_character".
            Paths that resolve outside /app (absolute paths elsewhere, or ".."
            segments) are rejected.
        language: Language passed to genie_tts.load_character; also the fallback
            reference language.

    Reference detection: if <model dir>/prompt_wav.json exists, its "default"
    entry supplies wav_path / prompt_text / prompt_lang. Otherwise, if
    <model dir>/ref.wav exists, it is used with an empty transcript. If neither
    exists, REF_CACHE is left unchanged.

    Raises:
        HTTPException: 400 if model_path escapes /app, 404 if the directory does
            not exist, 500 if loading or reading the JSON fails.
    """
    import json

    app_root = "/app"
    # Normalise before checking so "../" segments or an absolute path cannot
    # point the loader at arbitrary locations of the filesystem.
    full_path = os.path.normpath(os.path.join(app_root, model_path))
    if os.path.commonpath([app_root, full_path]) != app_root:
        raise HTTPException(status_code=400, detail=f"Model path must be inside {app_root}: {model_path}")
    if not os.path.exists(full_path):
        raise HTTPException(status_code=404, detail=f"Model path not found: {full_path}")
    try:
        print(f"📦 Loading character: {character_name} from {full_path}")
        genie_tts.load_character(character_name, full_path, language)

        # Auto-detect the default reference audio configuration.
        prompt_json_path = os.path.join(full_path, "prompt_wav.json")
        ref_wav_path = os.path.join(full_path, "ref.wav")
        if os.path.exists(prompt_json_path):
            with open(prompt_json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            config = data.get("default", {})
            REF_CACHE[character_name] = {
                "path": os.path.join(full_path, config.get("wav_path", "ref.wav")),
                "text": config.get("prompt_text", ""),
                "lang": config.get("prompt_lang", language),
            }
            print(f"📖 Loaded ref info from JSON for {character_name}")
        elif os.path.exists(ref_wav_path):
            REF_CACHE[character_name] = {
                "path": ref_wav_path,
                "text": "",
                "lang": language,
            }
            print(f"🎵 Found ref.wav for {character_name}")
        return {"status": "success", "message": f"Character '{character_name}' loaded."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _remove_temp_files(*paths):
    """Best-effort deletion of temporary files.

    None entries and already-missing files are ignored. Other OS errors are
    logged rather than raised, so cleanup never masks the real response or error.
    """
    for path in paths:
        if not path:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            print(f"⚠️ Failed to remove temp file {path}: {err}")


def _discard_failed_upload(save_path, out_path):
    """Clean up after a failed /upload_and_tts request.

    The reference-audio cache is cleared before the uploaded clip is deleted,
    mirroring the success path, so no cache entry is left pointing at a
    removed file.
    """
    if save_path is not None:
        try:
            genie_tts.clear_reference_audio_cache()
        except Exception as cache_err:
            print(f"⚠️ Failed to clear reference audio cache: {cache_err}")
    _remove_temp_files(save_path, out_path)


@app.post("/upload_and_tts")
async def upload_and_tts(
    character_name: str = Form("Default"),
    prompt_text: str = Form(...),
    text: str = Form(...),
    language: str = Form("zh"),
    text_lang: str = Form(None),
    speed: float = Form(1.0),
    fragment_interval: float = Form(0.3),
    fade_duration: float = Form(0.0),  # fade-in/fade-out duration (seconds)
    file: UploadFile = File(...)
):
    """Synthesize speech using a reference clip uploaded with this request.

    The upload is written to a temporary WAV under /tmp and registered as the
    character's reference audio. The synthesized WAV is then streamed back.
    Both temporary files are deleted after streaming finishes, or immediately
    if the request fails.

    Args:
        character_name: Character to synthesize with. If genie_tts has no model
            under this name, MODELS_ROOT/<lower-cased name> is loaded, falling
            back to MODELS_ROOT/mzm.
        prompt_text: Transcript of the uploaded reference clip.
        text: Text to synthesize.
        language: Language passed to set_reference_audio (and load_character).
        text_lang: Forwarded to genie_tts.tts as text_language.
        speed, fragment_interval, fade_duration: Forwarded to genie_tts.tts.
        file: The uploaded reference audio.

    Raises:
        HTTPException: 500 if synthesis fails or no output file appears
            within roughly 5 seconds.
    """
    import uuid

    save_path = None
    out_path = None
    try:
        # Make sure the model is loaded.
        if not genie_tts.model_manager.get(character_name):
            print(f"⚠️ Character {character_name} not loaded, trying to load...")
            char_path = os.path.join(MODELS_ROOT, character_name.lower())
            if not os.path.exists(char_path):
                char_path = os.path.join(MODELS_ROOT, "mzm")  # fallback model
            genie_tts.load_character(character_name, char_path, language)

        # Millisecond timestamp plus a random suffix, so two uploads arriving
        # in the same millisecond never share (and clobber) a file.
        request_id = f"{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"
        os.makedirs("/tmp", exist_ok=True)
        save_path = f"/tmp/ref_{request_id}.wav"
        with open(save_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        print(f"🔥 [Custom] Using temp audio: {save_path}")
        genie_tts.set_reference_audio(character_name, save_path, prompt_text, language)

        out_path = f"/tmp/out_{request_id}.wav"
        genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang, speed=speed, fragment_interval=fragment_interval, fade_duration=fade_duration)

        # Defensive: wait up to 5 seconds (50 x 0.1 s) for the output file.
        for _ in range(50):
            if os.path.exists(out_path):
                break
            time.sleep(0.1)
        if not os.path.exists(out_path):
            raise HTTPException(status_code=500, detail="Audio file generation timed out or failed.")

        def iterfile():
            try:
                with open(out_path, "rb") as f:
                    # Fixed-size chunks: iterating a binary file directly
                    # splits on b"\n" and yields arbitrarily sized pieces.
                    yield from iter(lambda: f.read(64 * 1024), b"")
            finally:
                # Clear the ReferenceAudio cache BEFORE deleting the temp files;
                # otherwise the LRU cache keeps referencing the deleted
                # reference path and later requests fail.
                genie_tts.clear_reference_audio_cache()
                # NOTE(review): the purpose of this pause is not visible here;
                # presumably it lets file handles settle before deletion.
                time.sleep(0.5)
                _remove_temp_files(save_path, out_path)

        return StreamingResponse(iterfile(), media_type="audio/wav")
    except HTTPException:
        # Already a proper HTTP error (e.g. the timeout above): clean up and
        # re-raise unchanged instead of wrapping it into a "500: ..." detail.
        _discard_failed_upload(save_path, out_path)
        raise
    except Exception as e:
        _discard_failed_upload(save_path, out_path)
        print(f"❌ Error in upload/tts: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/tts")
async def dynamic_tts(
    text: str = Form(...),
    character_name: str = Form("Base"),
    prompt_text: str = Form(None),
    prompt_lang: str = Form("zh"),
    text_lang: str = Form(None),
    speed: float = Form(1.0),
    fragment_interval: float = Form(0.3),
    fade_duration: float = Form(0.0),  # fade-in/fade-out duration (seconds)
    use_default_ref: bool = Form(True)
):
    """General TTS endpoint that can switch between already-loaded characters.

    The character's cached default reference clip (REF_CACHE) is used. If the
    character has no entry, the "Base" entry's clip is borrowed. ``prompt_text``
    overrides the cached transcript without changing the audio file. The
    synthesized WAV is streamed back and its temporary file is deleted
    afterwards.

    Args:
        text_lang: Language of the target text. If it differs from the
            reference audio's language, this enables cross-lingual synthesis.
        prompt_lang: Passed to set_reference_audio as-is.
            NOTE(review): the cached REF_CACHE "lang" is not consulted here.
        use_default_ref: Currently unused.

    Raises:
        HTTPException: 404 if neither the character nor "Base" has a reference
            entry, 500 if synthesis fails or produces no output file.
    """
    import uuid

    out_path = None
    try:
        # Prefer the requested character's reference; fall back to Base's.
        ref_info = REF_CACHE.get(character_name)
        if not ref_info:
            ref_info = REF_CACHE.get("Base")
        if not ref_info:
            raise HTTPException(status_code=404, detail=f"Character {character_name} not loaded and no Base model available.")

        # Allow the API caller to override the reference transcript only.
        final_text = prompt_text if prompt_text else ref_info["text"]
        genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)

        # A seconds-resolution timestamp alone lets two requests in the same
        # second share one output file; the random suffix keeps it unique.
        out_path = f"/tmp/out_dyn_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
        genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang, speed=speed, fragment_interval=fragment_interval, fade_duration=fade_duration)

        # Wait up to 5 seconds (50 x 0.1 s) for the output file to appear.
        for _ in range(50):
            if os.path.exists(out_path):
                break
            time.sleep(0.1)
        # Never hand a missing file to the response.
        if not os.path.exists(out_path):
            raise HTTPException(status_code=500, detail="TTS processing failed. Output file was not generated.")

        def iter_output():
            # Close the handle and delete the temp file once streaming ends,
            # instead of leaking both.
            try:
                with open(out_path, "rb") as f:
                    yield from iter(lambda: f.read(64 * 1024), b"")
            finally:
                try:
                    os.remove(out_path)
                except OSError as err:
                    print(f"⚠️ Failed to remove temp file {out_path}: {err}")

        return StreamingResponse(iter_output(), media_type="audio/wav")
    except HTTPException:
        raise
    except Exception as e:
        # Drop any partially written output before reporting the failure.
        if out_path is not None and os.path.exists(out_path):
            try:
                os.remove(out_path)
            except OSError:
                pass
        print(f"❌ Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/tts-stream")
async def tts_stream(
    text: str,
    character_name: str = "Base",
    prompt_text: str = None,
    prompt_lang: str = "zh",
    text_lang: str = None,
    speed: float = 1.0,
):
    """Streaming TTS: return PCM audio chunks as soon as they are generated.

    Response body: audio/pcm (s16le, 32 kHz, mono). Playback example with ffplay:
        ffplay -f s16le -ar 32000 -ac 1 "http://xxx/tts-stream?text=你好世界"

    The character's cached reference clip is used, falling back to the "Base"
    entry. ``prompt_text`` overrides the cached transcript.
    """
    try:
        ref = REF_CACHE.get(character_name) or REF_CACHE.get("Base")
        if not ref:
            raise HTTPException(status_code=404, detail=f"Character {character_name} not found")
        genie_tts.set_reference_audio(
            character_name,
            ref["path"],
            prompt_text or ref["text"],
            prompt_lang,
        )

        pcm_headers = {
            "X-Audio-Sample-Rate": "32000",
            "X-Audio-Channels": "1",
            "X-Audio-Format": "s16le",
            "Cache-Control": "no-cache",
        }

        async def _pcm_chunks():
            # Relay every chunk from genie_tts.tts_async unchanged. This body
            # runs while the response streams, i.e. after the try below has
            # exited, so errors raised here are not converted by its handlers.
            stream = genie_tts.tts_async(
                character_name=character_name,
                text=text,
                play=False,
                split_sentence=True,
                text_language=text_lang,
                speed=speed,
            )
            async for piece in stream:
                yield piece

        return StreamingResponse(_pcm_chunks(), media_type="audio/pcm", headers=pcm_headers)
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Streaming Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
    """Health check: report "ok" and the names of characters in REF_CACHE."""
    character_names = [name for name in REF_CACHE]
    return {"status": "ok", "models": character_names}
if __name__ == "__main__":
    # Bind on every interface so the server is reachable from outside the
    # container; 7860 is presumably the port the hosting Space routes to.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)
|