|
|
import builtins
import io
import os
import shutil
import time
import uuid

import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
|
def _auto_confirm(prompt=""):
    """Stand-in for ``input()`` that answers every prompt with "y".

    NOTE(review): presumably installed so that interactive confirmation
    prompts (e.g. from genie_tts setup) never block the server — confirm.
    """
    return "y"


builtins.input = _auto_confirm

# Data directory for genie_tts; exported via the environment before
# ``genie_tts`` is imported further down (presumably read at import time).
GENIE_DATA_DIR = "/app/GenieData"
os.environ["GENIE_DATA_DIR"] = GENIE_DATA_DIR

# The G2P sub-folder is used as the marker that the assets are already present.
_g2p_marker = "/app/GenieData/G2P"
if not os.path.exists(_g2p_marker):
    print("📦 Downloading GenieData Assets...")
    snapshot_download(
        repo_id="High-Logic/Genie",
        allow_patterns=["GenieData/*"],
        local_dir="/app",
        local_dir_use_symlinks=False,
    )
|
|
|
|
|
import genie_tts |
|
|
|
|
|
app = FastAPI()

# Root folder for character model directories (created if missing).
MODELS_ROOT = "/app/models"
os.makedirs(MODELS_ROOT, exist_ok=True)

# Characters registered at startup. Each row is:
# (character name, model directory, reference wav, reference transcript).
_PRESET_CHARACTERS = (
    (
        "Base",
        "/app/models/base",
        "/app/models/base/ref.wav",
        "琴是个称职的好团长。看到她认真工作的样子,就连我也忍不住想要多帮她一把。",
    ),
    (
        "god",
        "/app/models/god",
        "/app/models/god/ref.wav",
        "很多人的一生,写于纸上也不过几行,大多都是些无聊的故事啊。",
    ),
)

# Load every preset model (all Chinese), in declaration order.
for _name, _model_dir, _ref_wav, _ref_text in _PRESET_CHARACTERS:
    genie_tts.load_character(_name, _model_dir, "zh")

# Per-character reference clip + transcript, consumed by the /tts endpoint.
REF_CACHE = {
    _name: {"path": _ref_wav, "text": _ref_text, "lang": "zh"}
    for _name, _model_dir, _ref_wav, _ref_text in _PRESET_CHARACTERS
}
|
|
|
|
|
@app.post("/load_model")
async def load_model(character_name: str = Form(...), model_path: str = Form(...), language: str = Form("zh")):
    """Dynamically load a new character model.

    Args:
        character_name: Name to register the model under in genie_tts.
        model_path: Path relative to /app, e.g. "models/my_character".
            Absolute paths are accepted only if they point inside /app.
        language: Language code passed to ``genie_tts.load_character``.

    Returns:
        ``{"status": "success", "message": ...}`` on success.

    Raises:
        HTTPException: 400 if the path escapes /app, 404 if it does not
            exist, 500 if genie_tts fails to load the model.
    """
    full_path = os.path.join("/app", model_path)

    # os.path.join drops "/app" when model_path is absolute, and ".." segments
    # (or symlinks) can climb out of it; resolve both sides and require the
    # target to stay inside /app, as the documented contract says.
    app_root = os.path.realpath("/app")
    resolved = os.path.realpath(full_path)
    if os.path.commonpath([app_root, resolved]) != app_root:
        raise HTTPException(status_code=400, detail=f"Model path must be inside /app: {model_path}")

    if not os.path.exists(full_path):
        raise HTTPException(status_code=404, detail=f"Model path not found: {full_path}")

    try:
        print(f"📦 Loading character: {character_name} from {full_path}")
        genie_tts.load_character(character_name, full_path, language)
        return {"status": "success", "message": f"Character '{character_name}' loaded."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.post("/upload_and_tts")
async def upload_and_tts(
    character_name: str = Form("Default"),
    prompt_text: str = Form(...),
    text: str = Form(...),
    language: str = Form("zh"),
    file: UploadFile = File(...),
):
    """Upload a temporary reference clip and synthesize speech with it.

    The uploaded clip is saved to /tmp, set as the reference audio for
    ``character_name``, and ``text`` is synthesized to a temporary wav that
    is streamed back. Both temp files are deleted after streaming ends
    (including on client disconnect) or immediately if synthesis fails.

    Args:
        character_name: Character previously registered with genie_tts.
            NOTE(review): the "Default" default matches no character loaded
            at startup ("Base", "god") — confirm the intended default.
        prompt_text: Transcript of the uploaded reference clip.
        text: Text to synthesize.
        language: Language code of the reference clip/transcript.
        file: The reference audio upload (written to disk as .wav).

    Returns:
        A streaming ``audio/wav`` response.

    Raises:
        HTTPException: 500 if saving, reference setup, or synthesis fails.
    """
    # A random token instead of a millisecond timestamp, so concurrent
    # requests can never overwrite each other's temp files.
    token = uuid.uuid4().hex
    save_path = f"/tmp/ref_{token}.wav"
    out_path = f"/tmp/out_{token}.wav"

    def _discard(*paths):
        # Best-effort removal of temp files; a missing file is not an error.
        for path in paths:
            try:
                os.remove(path)
            except OSError:
                pass

    try:
        os.makedirs("/tmp", exist_ok=True)
        with open(save_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        print(f"🔥 [Custom] Using temp audio for {character_name}: {save_path}")
        genie_tts.set_reference_audio(character_name, save_path, prompt_text, language)
        genie_tts.tts(character_name, text, save_path=out_path, play=False)

        # Open before returning so a missing output file becomes a 500
        # instead of a stream that breaks after headers were sent.
        audio = open(out_path, "rb")
    except Exception as e:
        _discard(save_path, out_path)
        print(f"❌ Error in upload/tts: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    def iterfile():
        try:
            with audio:
                # Fixed-size chunks: iterating a binary file yields
                # newline-delimited pieces of arbitrary size.
                while True:
                    chunk = audio.read(64 * 1024)
                    if not chunk:
                        break
                    yield chunk
        finally:
            # Runs on normal completion and when the generator is closed
            # early (e.g. client disconnect), so /tmp does not fill up.
            _discard(save_path, out_path)

    return StreamingResponse(iterfile(), media_type="audio/wav")
|
|
|
|
|
@app.post("/tts")
async def dynamic_tts(
    text: str = Form(...),
    character_name: str = Form("Default"),
    prompt_text: str = Form(None),
    prompt_lang: str = Form("zh"),
    use_default_ref: bool = Form(True),
):
    """General-purpose TTS endpoint that can switch between loaded characters.

    The reference clip comes from ``REF_CACHE``: the entry for
    ``character_name`` if present, otherwise the "Default" entry, otherwise
    the "Base" entry.

    Args:
        text: Text to synthesize.
        character_name: Character previously registered with genie_tts.
            NOTE(review): the "Default" default matches no character loaded
            at startup ("Base", "god") — confirm the intended default.
        prompt_text: Transcript of the reference clip; defaults to the
            cached transcript of the chosen reference entry.
        prompt_lang: Language code passed with the reference clip.
        use_default_ref: Currently unused; kept for API compatibility.

    Returns:
        A streaming ``audio/wav`` response; the temporary output file is
        deleted once streaming ends (including on client disconnect).

    Raises:
        HTTPException: 404 if no reference entry is available at all,
            500 if reference setup or synthesis fails.
    """
    # The previous REF_CACHE.get(name, REF_CACHE["Default"]) evaluated its
    # fallback eagerly, and "Default" is not a key, so every call raised
    # KeyError. Resolve the fallback lazily instead.
    ref_info = (
        REF_CACHE.get(character_name)
        or REF_CACHE.get("Default")
        or REF_CACHE.get("Base")
    )
    if ref_info is None:
        raise HTTPException(status_code=404, detail=f"No reference audio configured for '{character_name}'")

    # Unique per request; a seconds-resolution timestamp collided under load.
    out_path = f"/tmp/out_dyn_{uuid.uuid4().hex}.wav"

    def _discard(path):
        # Best-effort removal of the temp file; a missing file is not an error.
        try:
            os.remove(path)
        except OSError:
            pass

    try:
        final_text = prompt_text if prompt_text else ref_info["text"]
        genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)
        genie_tts.tts(character_name, text, save_path=out_path, play=False)
        audio = open(out_path, "rb")
    except Exception as e:
        _discard(out_path)
        print(f"❌ Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    def iterfile():
        try:
            with audio:
                while True:
                    chunk = audio.read(64 * 1024)
                    if not chunk:
                        break
                    yield chunk
        finally:
            # Close-and-delete also runs when the stream is abandoned early.
            _discard(out_path)

    return StreamingResponse(iterfile(), media_type="audio/wav")
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, serve the app on every interface, port 7860.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)
|
|
|