# Nanny7's picture
# feat: add streaming TTS endpoint /tts-stream
# 79f89ec
import builtins
import os
import sys
import shutil
import io
import time
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
# Critical: replace the built-in input() with a stub that always answers "y".
# This must run before the genie_tts import below.
# NOTE(review): presumably genie_tts asks for interactive confirmation while it
# is being imported or initialised — confirm against the genie_tts source.
builtins.input = lambda prompt="": "y"
# Use the local genie_tts source tree (rather than an installed package):
# put this file's directory at the front of sys.path so local modules win.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# Path layout of the Space container; when running locally, make sure this
# directory exists.
os.environ["GENIE_DATA_DIR"] = "/app/GenieData"
# Automatic download is disabled; GenieData is assumed to be pre-installed in
# the image.
# if not os.path.exists("/app/GenieData/G2P"):
#     print("📦 Downloading GenieData Assets...")
#     from huggingface_hub import snapshot_download
#     snapshot_download(repo_id="High-Logic/Genie", allow_patterns=["GenieData/*"], local_dir="/app", local_dir_use_symlinks=False)
import genie_tts
# The FastAPI application object that all endpoints below are registered on.
app = FastAPI()
# Root directory where character models are stored.
MODELS_ROOT = "/app/models"
os.makedirs(MODELS_ROOT, exist_ok=True)
# Default setup: load the characters in models/base and models/god when this
# module is imported.
genie_tts.load_character("Base", "/app/models/base", "zh")
genie_tts.load_character("god", "/app/models/god", "zh")
# Default reference audio for each character, keyed by character name.
# Each entry holds the audio file path, its reference text and its language.
REF_CACHE = {
    "Base": {
        "path": "/app/models/base/ref.wav",
        "text": "琴是个称职的好团长。看到她认真工作的样子,就连我也忍不住想要多帮她一把。",
        "lang": "zh"
    },
    "god": {
        "path": "/app/models/god/ref.wav",
        "text": "很多人的一生,写于纸上也不过几行,大多都是些无聊的故事啊。",
        "lang": "zh"
    }
}
@app.post("/load_model")
async def load_model(
    character_name: str = Form(...),
    model_path: str = Form(...),
    language: str = Form("zh"),
):
    """Load a character model at runtime and register its default reference audio.

    Args:
        character_name: Name under which the character is loaded and cached.
        model_path: Model directory relative to ``/app``,
            e.g. ``"models/my_character"``.
        language: Language passed to ``genie_tts.load_character``; also used as
            the reference language when the model directory does not specify one.

    Returns:
        A dict with ``status`` and ``message`` keys on success.

    Raises:
        HTTPException: 400 if ``model_path`` points outside ``/app``, 404 if the
            directory does not exist, 500 if loading the model or reading its
            reference configuration fails.
    """
    import json

    app_root = "/app"
    full_path = os.path.join(app_root, model_path)

    # model_path is untrusted form input: os.path.join discards "/app" when it
    # is given an absolute path, and ".." segments can climb out of it.  The
    # check is purely lexical (abspath, not realpath) so deployments that
    # symlink directories into /app keep working.
    normalized = os.path.abspath(full_path)
    if os.path.commonpath([app_root, normalized]) != app_root:
        raise HTTPException(
            status_code=400,
            detail=f"Model path must be inside {app_root}: {model_path}",
        )

    if not os.path.exists(full_path):
        raise HTTPException(status_code=404, detail=f"Model path not found: {full_path}")

    try:
        print(f"📦 Loading character: {character_name} from {full_path}")
        genie_tts.load_character(character_name, full_path, language)

        # Auto-detect the reference audio configuration: prefer the JSON
        # description, otherwise fall back to a bare ref.wav with empty text.
        prompt_json_path = os.path.join(full_path, "prompt_wav.json")
        ref_wav_path = os.path.join(full_path, "ref.wav")
        if os.path.exists(prompt_json_path):
            with open(prompt_json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            config = data.get("default", {})
            REF_CACHE[character_name] = {
                "path": os.path.join(full_path, config.get("wav_path", "ref.wav")),
                "text": config.get("prompt_text", ""),
                "lang": config.get("prompt_lang", language),
            }
            print(f"📖 Loaded ref info from JSON for {character_name}")
        elif os.path.exists(ref_wav_path):
            REF_CACHE[character_name] = {
                "path": ref_wav_path,
                "text": "",
                "lang": language,
            }
            print(f"🎵 Found ref.wav for {character_name}")

        return {"status": "success", "message": f"Character '{character_name}' loaded."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/upload_and_tts")
async def upload_and_tts(
    character_name: str = Form("Default"),
    prompt_text: str = Form(...),
    text: str = Form(...),
    language: str = Form("zh"),
    text_lang: str = Form(None),
    speed: float = Form(1.0),
    fragment_interval: float = Form(0.3),
    fade_duration: float = Form(0.0),  # fade in/out duration, in seconds
    file: UploadFile = File(...)
):
    """Synthesize speech using an uploaded, temporary reference audio file.

    Args:
        character_name: Character to synthesize with; loaded on demand from
            ``MODELS_ROOT`` when it is not loaded yet.
        prompt_text: Reference text passed to ``genie_tts.set_reference_audio``
            together with the uploaded audio.
        text: Text to synthesize.
        language: Language used when loading the character and for the
            reference audio.
        text_lang: Forwarded to ``genie_tts.tts`` as ``text_language``.
        speed: Forwarded to ``genie_tts.tts``.
        fragment_interval: Forwarded to ``genie_tts.tts``.
        fade_duration: Forwarded to ``genie_tts.tts``.
        file: Uploaded reference audio.

    Returns:
        A ``StreamingResponse`` with the generated ``audio/wav`` data.

    Raises:
        HTTPException: 500 if synthesis fails or the output file does not
            appear within about five seconds.
    """
    import uuid

    # A random token keeps concurrent requests from sharing temp file names
    # (a millisecond timestamp is not unique under load).
    token = uuid.uuid4().hex
    save_path = f"/tmp/ref_{token}.wav"
    out_path = f"/tmp/out_{token}.wav"

    def release_temp_files():
        """Best-effort cleanup of this request's reference cache and temp files."""
        # Clear the reference-audio cache before deleting the files; otherwise
        # the cache keeps pointing at a deleted path and later requests fail.
        try:
            genie_tts.clear_reference_audio_cache()
        except Exception as err:
            print(f"⚠️ Could not clear reference audio cache: {err}")
        for path in (save_path, out_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as err:
                print(f"⚠️ Could not remove temp file {path}: {err}")

    def iterfile():
        """Stream the generated file, then clean up once streaming ends."""
        try:
            with open(out_path, "rb") as f:
                yield from f
        finally:
            # Short grace period before the files are removed.
            time.sleep(0.5)
            release_temp_files()

    handed_off = False  # True once cleanup responsibility moved to iterfile()
    try:
        # Make sure the model is loaded.
        if not genie_tts.model_manager.get(character_name):
            print(f"⚠️ Character {character_name} not loaded, trying to load...")
            char_path = os.path.join(MODELS_ROOT, character_name.lower())
            if not os.path.exists(char_path):
                char_path = os.path.join(MODELS_ROOT, "mzm")  # fallback model directory
            genie_tts.load_character(character_name, char_path, language)

        os.makedirs("/tmp", exist_ok=True)
        with open(save_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        print(f"🔥 [Custom] Using temp audio: {save_path}")
        genie_tts.set_reference_audio(character_name, save_path, prompt_text, language)

        # Run TTS.
        genie_tts.tts(
            character_name,
            text,
            save_path=out_path,
            play=False,
            text_language=text_lang,
            speed=speed,
            fragment_interval=fragment_interval,
            fade_duration=fade_duration,
        )

        # Wait for the output file to show up (at most ~5 seconds).
        wait_time = 0
        while not os.path.exists(out_path) and wait_time < 50:
            time.sleep(0.1)
            wait_time += 1
        if not os.path.exists(out_path):
            raise HTTPException(status_code=500, detail="Audio file generation timed out or failed.")

        response = StreamingResponse(iterfile(), media_type="audio/wav")
        handed_off = True
        return response
    except HTTPException:
        # Pass our own HTTP errors through unchanged instead of re-wrapping them.
        raise
    except Exception as e:
        print(f"❌ Error in upload/tts: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # On any failure before streaming starts, nobody else will remove the
        # uploaded reference or a partial output, so do it here.
        if not handed_off:
            release_temp_files()
@app.post("/tts")
async def dynamic_tts(
    text: str = Form(...),
    character_name: str = Form("Base"),
    prompt_text: str = Form(None),
    prompt_lang: str = Form("zh"),
    text_lang: str = Form(None),
    speed: float = Form(1.0),
    fragment_interval: float = Form(0.3),
    fade_duration: float = Form(0.0),  # fade in/out duration, in seconds
    use_default_ref: bool = Form(True)
):
    """General-purpose TTS endpoint that can switch between loaded characters.

    Args:
        text: Text to synthesize.
        character_name: Character to synthesize with. Its cached reference
            audio is used; if it has none, the "Base" reference is used.
        prompt_text: Optional override for the reference text (the reference
            audio file itself is not changed).
        prompt_lang: Language passed to ``genie_tts.set_reference_audio``.
        text_lang: Target text language, forwarded as ``text_language``.
            Per the original author's note, a value different from the
            reference audio's language enables cross-lingual synthesis.
        speed: Forwarded to ``genie_tts.tts``.
        fragment_interval: Forwarded to ``genie_tts.tts``.
        fade_duration: Forwarded to ``genie_tts.tts``.
        use_default_ref: Accepted for compatibility; currently unused.

    Returns:
        A ``StreamingResponse`` with the generated ``audio/wav`` data.

    Raises:
        HTTPException: 404 if neither the character nor "Base" has a cached
            reference, 500 if synthesis fails or produces no output file.
    """
    import uuid

    try:
        # Prefer the requested character, fall back to Base, otherwise fail.
        ref_info = REF_CACHE.get(character_name)
        if not ref_info:
            ref_info = REF_CACHE.get("Base")
        if not ref_info:
            raise HTTPException(status_code=404, detail=f"Character {character_name} not loaded and no Base model available.")

        # The API may override the reference text without changing the audio file.
        final_text = prompt_text if prompt_text else ref_info["text"]
        genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)

        # A random token instead of a whole-second timestamp: two requests in
        # the same second must not write to (and stream from) the same file.
        out_path = f"/tmp/out_dyn_{uuid.uuid4().hex}.wav"
        genie_tts.tts(
            character_name,
            text,
            save_path=out_path,
            play=False,
            text_language=text_lang,
            speed=speed,
            fragment_interval=fragment_interval,
            fade_duration=fade_duration,
        )

        # Wait for the output file to show up (at most ~5 seconds).
        wait_time = 0
        while not os.path.exists(out_path) and wait_time < 50:
            time.sleep(0.1)
            wait_time += 1

        # Never hand out a file that was not actually generated.
        if not os.path.exists(out_path):
            raise HTTPException(status_code=500, detail="TTS processing failed. Output file was not generated.")

        def iterfile():
            """Stream the output file, then close and delete it."""
            try:
                with open(out_path, "rb") as f:
                    yield from f
            finally:
                try:
                    os.remove(out_path)
                except FileNotFoundError:
                    pass
                except OSError as err:
                    print(f"⚠️ Could not remove temp file {out_path}: {err}")

        return StreamingResponse(iterfile(), media_type="audio/wav")
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/tts-stream")
async def tts_stream(
    text: str,
    character_name: str = "Base",
    prompt_text: str = None,
    prompt_lang: str = "zh",
    text_lang: str = None,
    speed: float = 1.0,
):
    """Streaming TTS endpoint: returns PCM audio while it is being generated.

    Response format: audio/pcm (s16le, 32kHz, mono), as advertised in the
    X-Audio-* response headers.

    Example playback with ffplay:
        ffplay -f s16le -ar 32000 -ac 1 "http://xxx/tts-stream?text=你好世界"
    """
    try:
        # Look up the reference audio: the requested character first, then Base.
        reference = REF_CACHE.get(character_name)
        if not reference:
            reference = REF_CACHE.get("Base")
        if not reference:
            raise HTTPException(status_code=404, detail=f"Character {character_name} not found")

        # A caller-supplied prompt text overrides the cached reference text.
        if prompt_text:
            reference_text = prompt_text
        else:
            reference_text = reference["text"]
        genie_tts.set_reference_audio(character_name, reference["path"], reference_text, prompt_lang)

        async def audio_generator():
            """Relay audio chunks from genie_tts.tts_async as they arrive."""
            chunks = genie_tts.tts_async(
                character_name=character_name,
                text=text,
                play=False,
                split_sentence=True,
                text_language=text_lang,
                speed=speed,
            )
            async for piece in chunks:
                yield piece

        stream_headers = {
            "X-Audio-Sample-Rate": "32000",
            "X-Audio-Channels": "1",
            "X-Audio-Format": "s16le",
            "Cache-Control": "no-cache",
        }
        return StreamingResponse(
            audio_generator(),
            media_type="audio/pcm",
            headers=stream_headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Streaming Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
    """Report liveness plus the character names present in REF_CACHE."""
    registered = list(REF_CACHE)
    return {"status": "ok", "models": registered}
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)