"""Bankme TTS HTTP service.

Loads an F5-TTS ``DiT`` model and vocoder at import time, then exposes a
FastAPI app whose ``POST /tts`` endpoint synthesizes speech for a named voice.
Voice references are pickled ``(audio, sr, ref_text)`` triples downloaded on
demand from the ``voice/`` folder of the model repo on the Hugging Face Hub.
"""

import os
import pickle
import tempfile

import soundfile as sf
import uvicorn
from f5_tts.infer.utils_infer import infer_process, load_model, load_vocoder
from f5_tts.model import DiT
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download, snapshot_download
from starlette.background import BackgroundTask
from vinorm import TTSnorm

# Hugging Face repo holding the checkpoint, its vocab/config and voice files.
_REPO_ID = "GexSay/stt1beta"

# Load the model and vocoder up front (network + disk I/O at import time).
hf_token = os.environ.get("HF_TOKEN")
print("🔄 Đang tải models và voice...")

# 1. Load the TTS model.
vocoder = load_vocoder()
model_ckpt = hf_hub_download(
    repo_id=_REPO_ID, filename="model_last.pt", repo_type="model", token=hf_token
)
# NOTE(review): this file is named config.json but is passed as `vocab_file`;
# confirm that is what load_model expects.
vocab_file = hf_hub_download(
    repo_id=_REPO_ID, filename="config.json", repo_type="model", token=hf_token
)
model = load_model(
    DiT,
    dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
    ckpt_path=model_ckpt,
    vocab_file=vocab_file,
)

# Cache of voice name -> local path of its downloaded .pkl file.
pkl_dict = {}

app = FastAPI(title="Bankme TTS API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def post_process(text: str):
    """Collapse doubled punctuation tokens and strip double quotes.

    The text is padded with spaces so that tokens at either edge also match
    the space-delimited patterns; runs of whitespace are then collapsed to
    single spaces and the edges trimmed.
    """
    text = " " + text + " "
    text = text.replace(" . . ", " . ").replace(" .. ", " . ")
    text = text.replace(" , , ", " , ").replace(" ,, ", " , ")
    text = text.replace('"', "")
    return " ".join(text.split())


def _remove_file(path: str) -> None:
    """Delete *path*, ignoring a file that no longer exists."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _resolve_voice(voice: str) -> str:
    """Return the local .pkl path for *voice*, downloading it on first use.

    Successful downloads are cached in ``pkl_dict``.

    Raises:
        HTTPException: 404 if the voice file cannot be downloaded from the Hub.
    """
    cached = pkl_dict.get(voice)
    if cached is not None:
        return cached

    print(f"🔄 Voice '{voice}' chưa có local, thử tải từ HF Hub...")
    try:
        pkl_path = hf_hub_download(
            repo_id=_REPO_ID,
            filename=f"voice/{voice}.pkl",
            repo_type="model",
            token=hf_token,
        )
    except Exception as e:
        print(f"❌ Không thể tải voice '{voice}' từ HF Hub: {e}")
        # Only voices downloaded earlier in this process are known here.
        loaded = ", ".join(sorted(pkl_dict)) or "none"
        raise HTTPException(
            status_code=404,
            detail=(
                f"Voice '{voice}' not found and cannot be downloaded. "
                f"Loaded voices: {loaded}"
            ),
        ) from e

    pkl_dict[voice] = pkl_path
    print(f"✅ Đã tải voice '{voice}' thành công")
    return pkl_path


def _write_temp_wav(wave, sample_rate) -> str:
    """Write *wave* to a new temporary ``.wav`` file and return its path.

    The caller owns the returned file and is responsible for deleting it.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        temp_path = tmp_file.name
    # Write after our handle is closed so soundfile can reopen the path on
    # any platform; don't leave an empty file behind if writing fails.
    try:
        sf.write(temp_path, wave, sample_rate)
    except Exception:
        _remove_file(temp_path)
        raise
    return temp_path


@app.get("/")
async def root():
    """Report that the service is up."""
    return {"message": "Bankme TTS", "status": "running"}


@app.post("/tts")
async def generate_tts(voice: str, text: str, speed: float = 1.0):
    """Synthesize *text* in the reference *voice* and return it as a WAV file.

    Args:
        voice: Name of the voice file ``voice/<voice>.pkl`` in the model repo.
        text: Text to speak; normalized with ``TTSnorm`` and ``post_process``.
        speed: Speed factor forwarded to ``infer_process``; must be > 0.

    Returns:
        A ``FileResponse`` with ``audio/wav`` content. The temporary file
        backing it is deleted after the response has been sent.

    Raises:
        HTTPException: 400 for invalid input, 404 for an unknown voice,
            500 for any other failure during generation.
    """
    # NOTE(review): this handler is `async` but performs blocking download,
    # inference and file I/O, so it stalls the event loop (and every other
    # request) while it runs. Consider a plain `def` endpoint or a threadpool
    # once it is confirmed the model can be used concurrently.
    try:
        # Validate input.
        if not voice:
            raise HTTPException(status_code=400, detail="Voice is required")
        # The voice name becomes part of a repo file path; keep it inside
        # the voice/ folder.
        if "/" in voice or "\\" in voice:
            raise HTTPException(status_code=400, detail="Invalid voice name")
        if not text.strip():
            raise HTTPException(status_code=400, detail="Text is required")
        # `not speed > 0` also rejects NaN.
        if not speed > 0:
            raise HTTPException(status_code=400, detail="Speed must be positive")

        pkl_path = _resolve_voice(voice)

        # Load the voice reference data from its pickle.
        # NOTE(review): pickle.load can execute arbitrary code; this is only
        # safe while the voice repo is fully trusted.
        with open(pkl_path, "rb") as f:
            audio, sr, ref_text = pickle.load(f)

        # Normalize the input text.
        processed_text = post_process(TTSnorm(text, punc=True)).lower()

        # Generate audio.
        final_wave, final_sr, _ = infer_process(
            audio,
            sr,
            ref_text.lower(),
            processed_text,
            model,
            vocoder,
            nfe_step=8,
            speed=speed,
        )

        temp_path = _write_temp_wav(final_wave, final_sr)
        return FileResponse(
            temp_path,
            media_type="audio/wav",
            filename=f"tts_{voice}.wav",
            background=BackgroundTask(_remove_file, temp_path),
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Generation error: {str(e)}"
        ) from e


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)