| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import FileResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| import uvicorn |
| import os |
| import tempfile |
| import pickle |
| from vinorm import TTSnorm |
| from f5_tts.model import DiT |
| from f5_tts.infer.utils_infer import load_vocoder, load_model, infer_process |
| from huggingface_hub import hf_hub_download, snapshot_download |
| import soundfile as sf |
|
|
| |
# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")

print("🔄 Đang tải models và voice...")

# Repository holding the checkpoint and the vocab file fetched below.
_HF_REPO_ID = "GexSay/stt1beta"

# Keyword configuration handed to load_model together with the DiT class.
_DIT_ARCH = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "conv_layers": 4,
}

# Everything below runs once at import time, in this order:
# vocoder, checkpoint download, vocab download, model construction.
vocoder = load_vocoder()
model_ckpt = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="model_last.pt",
    repo_type="model",
    token=hf_token,
)
# The repo's config.json is what gets passed to load_model as vocab_file.
vocab_file = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="config.json",
    repo_type="model",
    token=hf_token,
)
model = load_model(
    DiT,
    _DIT_ARCH,
    ckpt_path=model_ckpt,
    vocab_file=vocab_file,
)

# Voice name -> local path of that voice's pickle file; starts empty.
pkl_dict = {}
|
|
# ASGI application object for the service.
app = FastAPI(title="Bankme TTS API")

# CORS is fully open: every origin, method and header is allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
def post_process(text: str):
    """Tidy normalized text: collapse doubled '.'/',' tokens, drop double quotes, squeeze whitespace.

    Each substitution is applied once, left to right, in the listed order, so
    a run of three or more punctuation tokens is only partially collapsed.
    The result has no leading/trailing whitespace and single spaces between words.
    """
    # Ordered (needle, replacement) pairs; the order matters and must not change.
    substitutions = (
        (" . . ", " . "),
        (" .. ", " . "),
        (" , , ", " , "),
        (" ,, ", " , "),
        ('"', ""),
    )

    # Pad with spaces so punctuation at either end still matches the
    # space-delimited needles above.
    padded = f" {text} "
    for needle, replacement in substitutions:
        padded = padded.replace(needle, replacement)

    words = padded.split()
    return " ".join(words)
|
|
@app.get("/")
async def root():
    """Return a small static JSON payload identifying the service and its status."""
    payload = dict(message="Bankme TTS", status="running")
    return payload
|
|
def _resolve_voice_path(voice: str) -> str:
    """Return the local pickle path for ``voice``, downloading it on a cache miss.

    Raises:
        HTTPException: 404 when the voice file cannot be downloaded.
    """
    cached_path = pkl_dict.get(voice)
    if cached_path is not None:
        return cached_path

    print(f"🔄 Voice '{voice}' chưa có local, thử tải từ HF Hub...")
    try:
        # NOTE(review): ``voice`` comes straight from the request and is
        # interpolated into the repo file path; consider restricting it to a
        # safe character set.
        pkl_path = hf_hub_download(
            repo_id="GexSay/stt1beta",
            filename=f"voice/{voice}.pkl",
            repo_type="model",
            token=hf_token,
        )
    except Exception as e:
        print(f"❌ Không thể tải voice '{voice}' từ HF Hub: {e}")
        # Voices already cached by this process; this name was previously
        # undefined, which turned the intended 404 into a NameError / 500.
        available_voice = sorted(pkl_dict)
        raise HTTPException(
            status_code=404,
            detail=f"Voice '{voice}' not found and cannot be downloaded. Available voice: {available_voice}",
        ) from e

    pkl_dict[voice] = pkl_path
    print(f"✅ Đã tải voice '{voice}' thành công")
    return pkl_path


@app.post("/tts")
async def generate_tts(voice: str, text: str, speed: float = 1.0):
    """Synthesize ``text`` with the reference ``voice`` and return a WAV file.

    Args:
        voice: Reference voice name; looked up in ``pkl_dict`` and otherwise
            downloaded as ``voice/<voice>.pkl`` from the Hugging Face repo.
        text: Text to speak; must contain at least one non-whitespace character.
        speed: Value forwarded to ``infer_process`` as ``speed``; must be > 0.

    Returns:
        A ``FileResponse`` with media type ``audio/wav``. The temporary file
        backing it is deleted after the response has been sent.

    Raises:
        HTTPException: 400 for an empty voice/text or a non-positive speed,
            404 when the voice cannot be downloaded, 500 for any other
            failure while generating the audio.
    """
    # Function-scope import so this handler does not depend on edits to the
    # module's import block.
    from fastapi import BackgroundTasks

    if not voice:
        raise HTTPException(status_code=400, detail="Voice is required")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text is required")
    # ``not speed > 0`` also rejects NaN, which ``speed <= 0`` would let through.
    if not speed > 0:
        raise HTTPException(status_code=400, detail="Speed must be a positive number")

    pkl_path = _resolve_voice_path(voice)

    try:
        # SECURITY NOTE(review): pickle.load executes arbitrary code embedded
        # in the file. This is only safe while the Hugging Face repo content
        # is fully trusted; consider a non-executable format for voices.
        with open(pkl_path, "rb") as f:
            audio, sr, ref_text = pickle.load(f)

        processed_text = post_process(TTSnorm(text, punc=True)).lower()

        # NOTE(review): this call is synchronous inside an ``async def``, so
        # the event loop is blocked for the duration of the inference.
        final_wave, final_sr, _ = infer_process(
            audio,
            sr,
            ref_text.lower(),
            processed_text,
            model,
            vocoder,
            nfe_step=8,
            speed=speed,
        )

        # Reserve a unique path, then close the handle before writing so the
        # audio writer is not competing with an open descriptor.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            temp_path = tmp_file.name
        try:
            sf.write(temp_path, final_wave, final_sr)
        except Exception:
            # Do not leave a partial file behind when writing fails.
            os.remove(temp_path)
            raise
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e

    # delete=False means nothing removes the file automatically; schedule the
    # removal to run once the response body has been streamed.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, temp_path)

    return FileResponse(
        temp_path,
        media_type="audio/wav",
        filename=f"tts_{voice}.wav",
        background=cleanup,
    )
|
|
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )