Spaces:
Runtime error
Runtime error
| import builtins | |
| import os | |
| import sys | |
| import shutil | |
| import io | |
| import time | |
| import uvicorn | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import StreamingResponse | |
# Hijack input() BEFORE `import genie_tts` below (fastapi etc. are already
# imported at this point; what matters is that this runs before genie_tts).
# After this line every input() call returns "y" without reading stdin, so
# any interactive prompt answers itself — presumably genie_tts (or its
# setup) asks for confirmation on start-up; TODO confirm.
builtins.input = lambda prompt="": "y"
# Use the local genie_tts sources (not an installed package): put this file's
# directory at the front of sys.path so local modules are found first.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# Data directory path for the Space container; when running locally make sure
# this directory exists. Set before `import genie_tts` — presumably the
# package reads GENIE_DATA_DIR at import time (TODO confirm).
os.environ["GENIE_DATA_DIR"] = "/app/GenieData"
# Automatic asset download is disabled: GenieData is assumed to be
# pre-installed in the image. Re-enable the block below if it is not.
# if not os.path.exists("/app/GenieData/G2P"):
#     print("📦 Downloading GenieData Assets...")
#     from huggingface_hub import snapshot_download
#     snapshot_download(repo_id="High-Logic/Genie", allow_patterns=["GenieData/*"], local_dir="/app", local_dir_use_symlinks=False)
import genie_tts
# FastAPI application instance (run by uvicorn in the __main__ guard).
app = FastAPI()

# Root directory holding one sub-directory per character model.
MODELS_ROOT = "/app/models"
os.makedirs(MODELS_ROOT, exist_ok=True)

# Characters loaded at start-up (MODELS_ROOT/base and MODELS_ROOT/god).
# Paths are derived from MODELS_ROOT so relocating the models directory only
# requires changing that one constant.
genie_tts.load_character("Base", os.path.join(MODELS_ROOT, "base"), "zh")
genie_tts.load_character("god", os.path.join(MODELS_ROOT, "god"), "zh")
# Default reference audio for each character: the wav file path, its
# transcript, and the transcript's language. load_model() adds or replaces
# entries for characters loaded at runtime. Paths are derived from
# MODELS_ROOT so they stay in sync with the start-up model locations.
REF_CACHE = {
    "Base": {
        "path": os.path.join(MODELS_ROOT, "base", "ref.wav"),
        "text": "琴是个称职的好团长。看到她认真工作的样子,就连我也忍不住想要多帮她一把。",
        "lang": "zh",
    },
    "god": {
        "path": os.path.join(MODELS_ROOT, "god", "ref.wav"),
        "text": "很多人的一生,写于纸上也不过几行,大多都是些无聊的故事啊。",
        "lang": "zh",
    },
}
async def load_model(character_name: str = Form(...), model_path: str = Form(...), language: str = Form("zh")):
    """Load a character model at runtime.

    Args:
        character_name: Name to register the character under.
        model_path: Model directory relative to ``/app``, e.g.
            ``"models/my_character"``.
        language: Language passed to ``genie_tts.load_character``; also used
            as the reference-audio language when none is configured.

    Returns:
        ``{"status": "success", "message": ...}`` once the model is loaded.

    Raises:
        HTTPException: 400 if ``model_path`` points outside ``/app``;
            404 if the directory does not exist; 500 if loading fails.

    Side effect: if the model directory contains ``prompt_wav.json`` (its
    ``"default"`` entry is used) or else a ``ref.wav``, the character's
    default reference audio is recorded in ``REF_CACHE``.
    """
    import json

    app_root = "/app"
    # model_path comes from the client. os.path.join discards app_root for an
    # absolute path and keeps "../" segments, so normalize the result and
    # refuse anything that ends up outside /app.
    full_path = os.path.normpath(os.path.join(app_root, model_path))
    if os.path.commonpath([app_root, full_path]) != app_root:
        raise HTTPException(status_code=400, detail=f"Model path must be inside /app: {model_path}")
    if not os.path.exists(full_path):
        raise HTTPException(status_code=404, detail=f"Model path not found: {full_path}")

    try:
        print(f"📦 Loading character: {character_name} from {full_path}")
        genie_tts.load_character(character_name, full_path, language)

        # Auto-detect the default reference audio: prompt_wav.json wins over
        # a bare ref.wav.
        prompt_json_path = os.path.join(full_path, "prompt_wav.json")
        ref_wav_path = os.path.join(full_path, "ref.wav")
        if os.path.exists(prompt_json_path):
            with open(prompt_json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            config = data.get("default", {})
            REF_CACHE[character_name] = {
                "path": os.path.join(full_path, config.get("wav_path", "ref.wav")),
                "text": config.get("prompt_text", ""),
                "lang": config.get("prompt_lang", language),
            }
            print(f"📖 Loaded ref info from JSON for {character_name}")
        elif os.path.exists(ref_wav_path):
            REF_CACHE[character_name] = {
                "path": ref_wav_path,
                "text": "",
                "lang": language,
            }
            print(f"🎵 Found ref.wav for {character_name}")
        return {"status": "success", "message": f"Character '{character_name}' loaded."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def upload_and_tts(
    character_name: str = Form("Default"),
    prompt_text: str = Form(...),
    text: str = Form(...),
    language: str = Form("zh"),
    text_lang: str = Form(None),
    speed: float = Form(1.0),
    fragment_interval: float = Form(0.3),
    fade_duration: float = Form(0.0),  # fade-in/fade-out duration (seconds)
    file: UploadFile = File(...)
):
    """Synthesize speech using a reference clip uploaded with the request.

    If ``character_name`` is not loaded yet, it is loaded from
    ``MODELS_ROOT/<character_name lowercased>``, falling back to
    ``MODELS_ROOT/mzm`` when that directory does not exist.

    Args:
        character_name: Character (model) to speak with.
        prompt_text: Transcript of the uploaded reference clip.
        text: Text to synthesize.
        language: Language used to load the character and of ``prompt_text``.
        text_lang: Language of ``text`` (passed as ``text_language``).
        speed, fragment_interval, fade_duration: Passed through to
            ``genie_tts.tts``.
        file: Uploaded reference audio; saved to a temp file under /tmp.

    Returns:
        A ``StreamingResponse`` with the generated WAV (``audio/wav``). Once
        streaming ends, the reference-audio cache is cleared and both temp
        files are deleted.

    Raises:
        HTTPException: 500 if synthesis fails or no output file appears.
    """
    import asyncio
    import uuid

    # Random token rather than a millisecond timestamp: two requests in the
    # same millisecond would otherwise share (and clobber) the same temp files.
    token = uuid.uuid4().hex
    save_path = f"/tmp/ref_{token}.wav"
    out_path = f"/tmp/out_{token}.wav"

    def _cleanup(delay: float = 0.0) -> None:
        """Clear the reference-audio cache, then delete this request's temp files."""
        # Clear the cache BEFORE deleting: otherwise the LRU cache keeps
        # referencing the deleted reference file and later requests fail.
        try:
            genie_tts.clear_reference_audio_cache()
        except Exception as exc:
            print(f"⚠️ Failed to clear reference audio cache: {exc}")
        if delay:
            time.sleep(delay)
        for path in (save_path, out_path):
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError:
                pass  # best effort: a leftover temp file is not worth failing over

    def iterfile():
        """Yield the generated WAV in fixed-size chunks; clean up when done."""
        try:
            with open(out_path, "rb") as f:
                yield from iter(lambda: f.read(64 * 1024), b"")
        finally:
            _cleanup(delay=0.5)

    try:
        # Make sure the character model is loaded.
        if not genie_tts.model_manager.get(character_name):
            print(f"⚠️ Character {character_name} not loaded, trying to load...")
            char_path = os.path.join(MODELS_ROOT, character_name.lower())
            if not os.path.exists(char_path):
                # NOTE(review): hard-coded fallback model directory — confirm
                # MODELS_ROOT/mzm actually exists in the image.
                char_path = os.path.join(MODELS_ROOT, "mzm")
            genie_tts.load_character(character_name, char_path, language)

        os.makedirs("/tmp", exist_ok=True)
        with open(save_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        print(f"🔥 [Custom] Using temp audio: {save_path}")
        genie_tts.set_reference_audio(character_name, save_path, prompt_text, language)

        genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang, speed=speed, fragment_interval=fragment_interval, fade_duration=fade_duration)

        # Wait up to ~5 s for the output file. asyncio.sleep (not time.sleep)
        # so the event loop is not blocked while polling.
        for _ in range(50):
            if os.path.exists(out_path):
                break
            await asyncio.sleep(0.1)
        if not os.path.exists(out_path):
            raise HTTPException(status_code=500, detail="Audio file generation timed out or failed.")
    except HTTPException:
        # Already a proper HTTP error: re-raise as-is instead of re-wrapping
        # it (which would mangle its detail) — but still remove temp files.
        _cleanup()
        raise
    except Exception as e:
        _cleanup()
        print(f"❌ Error in upload/tts: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return StreamingResponse(iterfile(), media_type="audio/wav")
async def dynamic_tts(
    text: str = Form(...),
    character_name: str = Form("Base"),
    prompt_text: str = Form(None),
    prompt_lang: str = Form("zh"),
    text_lang: str = Form(None),
    speed: float = Form(1.0),
    fragment_interval: float = Form(0.3),
    fade_duration: float = Form(0.0),  # fade-in/fade-out duration (seconds)
    use_default_ref: bool = Form(True)
):
    """General-purpose TTS endpoint for an already-loaded character.

    The reference audio comes from ``REF_CACHE[character_name]``; if the
    character has no entry, the ``"Base"`` entry is used instead (synthesis
    still runs with ``character_name``).

    Args:
        text: Text to synthesize.
        character_name: Loaded character to speak with.
        prompt_text: Optional override of the reference transcript (the
            reference audio file itself is not changed).
        prompt_lang: Language of the reference transcript.
        text_lang: Language of ``text``; when it differs from the reference
            audio's language this enables cross-lingual synthesis.
        speed, fragment_interval, fade_duration: Passed through to
            ``genie_tts.tts``.
        use_default_ref: Accepted for API compatibility; not used here.

    Returns:
        A ``StreamingResponse`` with the generated WAV (``audio/wav``). The
        temporary output file is closed and deleted once it has been streamed.

    Raises:
        HTTPException: 404 if neither the character nor "Base" has reference
            audio; 500 if synthesis fails or no output file appears.
    """
    import asyncio
    import uuid

    try:
        # Prefer the requested character's reference audio, else Base's.
        ref_info = REF_CACHE.get(character_name) or REF_CACHE.get("Base")
        if not ref_info:
            raise HTTPException(status_code=404, detail=f"Character {character_name} not loaded and no Base model available.")

        # Allow the API caller to override the reference transcript only.
        final_text = prompt_text if prompt_text else ref_info["text"]
        genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)

        # Unique per request: a one-second timestamp let concurrent requests
        # overwrite each other's output, and a stale file from an earlier
        # request could satisfy the wait loop below.
        out_path = f"/tmp/out_dyn_{uuid.uuid4().hex}.wav"
        genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang, speed=speed, fragment_interval=fragment_interval, fade_duration=fade_duration)

        # Wait up to ~5 s for the output file without blocking the event loop.
        for _ in range(50):
            if os.path.exists(out_path):
                break
            await asyncio.sleep(0.1)
        # Never hand a non-existent file to the response.
        if not os.path.exists(out_path):
            raise HTTPException(status_code=500, detail="TTS processing failed. Output file was not generated.")
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    def iterfile():
        """Stream the WAV in chunks, then close and delete the temp file."""
        try:
            with open(out_path, "rb") as f:
                yield from iter(lambda: f.read(64 * 1024), b"")
        finally:
            try:
                os.remove(out_path)
            except OSError:
                pass  # best effort: a leftover temp file is not worth failing over

    return StreamingResponse(iterfile(), media_type="audio/wav")
async def tts_stream(
    text: str,
    character_name: str = "Base",
    prompt_text: str = None,
    prompt_lang: str = "zh",
    text_lang: str = None,
    speed: float = 1.0,
):
    """Stream synthesized speech as raw PCM while it is being generated.

    Reference audio is looked up in REF_CACHE for ``character_name``, falling
    back to the "Base" entry; a non-empty ``prompt_text`` overrides the stored
    transcript. The response body is whatever chunks ``genie_tts.tts_async``
    yields, sent as ``audio/pcm``; the X-Audio-* headers advertise s16le,
    32 kHz, mono.

    Playback example with ffplay:
        ffplay -f s16le -ar 32000 -ac 1 "http://xxx/tts-stream?text=你好世界"
    """
    try:
        ref_info = REF_CACHE.get(character_name) or REF_CACHE.get("Base")
        if not ref_info:
            raise HTTPException(status_code=404, detail=f"Character {character_name} not found")

        genie_tts.set_reference_audio(
            character_name,
            ref_info["path"],
            prompt_text or ref_info["text"],
            prompt_lang,
        )

        async def pcm_chunks():
            """Relay each audio chunk from genie_tts.tts_async to the client."""
            stream = genie_tts.tts_async(
                character_name=character_name,
                text=text,
                play=False,
                split_sentence=True,
                text_language=text_lang,
                speed=speed,
            )
            async for piece in stream:
                yield piece

        stream_headers = {
            "X-Audio-Sample-Rate": "32000",
            "X-Audio-Channels": "1",
            "X-Audio-Format": "s16le",
            "Cache-Control": "no-cache",
        }
        return StreamingResponse(pcm_chunks(), media_type="audio/pcm", headers=stream_headers)
    except HTTPException:
        raise
    except Exception as err:
        print(f"❌ Streaming Error: {err}")
        raise HTTPException(status_code=500, detail=str(err))
async def health():
    """Report service status and the character names present in REF_CACHE."""
    model_names = list(REF_CACHE)
    return {"status": "ok", "models": model_names}
if __name__ == "__main__":
    # Run the app directly: listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")