# gexsaytts / app.py
# Author: GexSay
# Last commit: "Update app.py" (cffa8ac, verified)
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import os
import tempfile
import pickle
from vinorm import TTSnorm
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import load_vocoder, load_model, infer_process
from huggingface_hub import hf_hub_download, snapshot_download
import soundfile as sf
# Load the models (and set up the voice cache) once at import time so every
# request reuses the same vocoder and TTS model.
hf_token = os.environ.get("HF_TOKEN")  # optional Hub token; None when unset
print("🔄 Đang tải models và voice...")

# 1. TTS model: vocoder first, then the DiT checkpoint and vocab from the Hub.
vocoder = load_vocoder()
_MODEL_REPO_ID = "GexSay/stt1beta"
model_ckpt = hf_hub_download(
    repo_id=_MODEL_REPO_ID,
    filename="model_last.pt",
    repo_type="model",
    token=hf_token,
)
# NOTE(review): config.json is handed to load_model as the vocab file; confirm
# this file really contains the vocabulary load_model expects.
vocab_file = hf_hub_download(
    repo_id=_MODEL_REPO_ID,
    filename="config.json",
    repo_type="model",
    token=hf_token,
)
_DIT_ARCH = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model = load_model(DiT, _DIT_ARCH, ckpt_path=model_ckpt, vocab_file=vocab_file)

# Voice name -> local path of its downloaded .pkl prompt file (filled lazily).
pkl_dict = {}
app = FastAPI(title="Bankme TTS API")

# Fully open CORS policy: any origin, any HTTP method, any request header.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
def post_process(text: str):
    """Tidy punctuation spacing in normalized text.

    Collapses a doubled period (" . . " or " .. ") into " . " and a doubled
    comma (" , , " or " ,, ") into " , ". It then removes every double-quote
    character, squeezes runs of whitespace to single spaces and trims the ends.

    Each replacement is a single left-to-right pass of ``str.replace``, so a
    run of three or more marks may be only partly collapsed.
    """
    # Pad both ends so the space-delimited patterns also match at the edges.
    padded = f" {text} "
    for old, new in (
        (" . . ", " . "),
        (" .. ", " . "),
        (" , , ", " , "),
        (" ,, ", " , "),
        ('"', ""),
    ):
        padded = padded.replace(old, new)
    return " ".join(padded.split())
@app.get("/")
async def root():
    """Landing endpoint: return the service name and a static "running" status."""
    return dict(message="Bankme TTS", status="running")
@app.post("/tts")
async def generate_tts(voice: str, text: str, speed: float = 1.0):
    """Synthesize ``text`` in the requested ``voice`` and return it as a WAV file.

    Args:
        voice: Voice name. Its prompt is ``voice/<voice>.pkl`` in the
            ``GexSay/stt1beta`` Hub repo. It is downloaded on first use and
            the local path is then cached in ``pkl_dict``.
        text: Text to synthesize. It is normalized with ``TTSnorm``, tidied
            with ``post_process`` and lowercased.
        speed: Speaking-rate multiplier forwarded to ``infer_process``. Must
            be a positive number.

    Returns:
        A ``FileResponse`` with the generated audio (``audio/wav``). The
        temporary file behind it is deleted after the response is sent.

    Raises:
        HTTPException: 400 for an empty voice/text or non-positive speed,
            404 if the voice file cannot be downloaded, 500 for any other
            failure during generation.
    """
    # Local import: Starlette ships with FastAPI. Used to delete the temp WAV
    # once the response has been sent.
    from starlette.background import BackgroundTask

    # NOTE(review): hf_hub_download, TTSnorm, infer_process and sf.write are
    # blocking calls executed directly on the event loop inside this async
    # endpoint, so other requests wait while a clip is generated. Moving them
    # to a threadpool would allow concurrent inference; check model
    # thread-safety and memory use before doing that.
    try:
        # Validate input
        if not voice:
            raise HTTPException(status_code=400, detail="Voice is required")
        if not text.strip():
            raise HTTPException(status_code=400, detail="Text is required")
        # A rate multiplier must be positive; `not speed > 0` also rejects NaN.
        if not speed > 0:
            raise HTTPException(status_code=400, detail="Speed must be a positive number")

        # Resolve the voice prompt file: reuse the cached path, otherwise
        # download voice/<voice>.pkl from the Hub repo.
        if voice in pkl_dict:
            pkl_path = pkl_dict[voice]
        else:
            print(f"🔄 Voice '{voice}' chưa có local, thử tải từ HF Hub...")
            try:
                # NOTE(review): `voice` is raw client input interpolated into a
                # repo path; consider rejecting "/" and ".." segments.
                pkl_path = hf_hub_download(
                    repo_id="GexSay/stt1beta",
                    filename=f"voice/{voice}.pkl",
                    repo_type="model",
                    token=hf_token,
                )
            except Exception as e:
                print(f"❌ Không thể tải voice '{voice}' từ HF Hub: {e}")
                # Previously this referenced an undefined name, which turned the
                # intended 404 into a NameError-driven 500. Only voices already
                # downloaded are known locally, so report those.
                raise HTTPException(
                    status_code=404,
                    detail=(
                        f"Voice '{voice}' not found and cannot be downloaded. "
                        f"Loaded voices: {sorted(pkl_dict)}"
                    ),
                ) from e
            pkl_dict[voice] = pkl_path
            print(f"✅ Đã tải voice '{voice}' thành công")

        # Load the voice prompt (reference audio, sample rate, reference text).
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # This is only acceptable while the Hub repo is trusted (presumably
        # owned by the service operator; confirm). Never load user-supplied
        # pickles here.
        with open(pkl_path, "rb") as f:
            audio, sr, ref_text = pickle.load(f)

        # Normalize the text, tidy punctuation, then lowercase it.
        processed_text = post_process(TTSnorm(text, punc=True)).lower()

        # Generate audio from the voice prompt and the processed text.
        final_wave, final_sr, _ = infer_process(
            audio, sr,
            ref_text.lower(),
            processed_text,
            model,
            vocoder,
            nfe_step=8,
            speed=speed,
        )

        # Write the clip to a fresh temp file. It used to be left on disk
        # forever; it is now removed if the write fails, and otherwise deleted
        # by a background task after the response is sent.
        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(temp_path, final_wave, final_sr)
        except Exception:
            os.remove(temp_path)
            raise

        return FileResponse(
            temp_path,
            media_type="audio/wav",
            filename=f"tts_{voice}.wav",
            background=BackgroundTask(os.remove, temp_path),
        )
    except HTTPException:
        # Deliberate 4xx responses pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
if __name__ == "__main__":
    # When run as a script, serve the app on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")