# vixtts-api / app.py
# (Hugging Face Space file header — author: nxhong, commit 01ae4f9 "Update app.py", verified)
import os
import tempfile
import threading

import gradio as gr
import requests
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from huggingface_hub import snapshot_download, hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# ===== MODEL SETUP =====
# Local cache directory for the viXTTS checkpoint files.
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
os.makedirs(checkpoint_dir, exist_ok=True)
api_app = FastAPI()
# Files the loader needs; download the snapshot only if any are missing.
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
if not all(f in os.listdir(checkpoint_dir) for f in required_files):
    snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir)
    # speakers_xtts.pth is fetched separately from the upstream coqui/XTTS-v2
    # repo (presumably it is not shipped in the viXTTS snapshot — confirm).
    hf_hub_download("coqui/XTTS-v2", "speakers_xtts.pth", local_dir=checkpoint_dir)
# Build the XTTS model from the downloaded config + checkpoint.
config = XttsConfig()
config.load_json(os.path.join(checkpoint_dir, "config.json"))
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)
# CPU-only: keep the model on CPU and force the GPT sub-module to fp32.
MODEL.cpu()
MODEL.gpt.float()
torch.set_num_threads(4)  # NOTE(review): 4 threads looks tuned for the host; verify
torch.backends.mkldnn.enabled = True
# Language codes exposed in the UI dropdown.
LANGS = ["vi", "en", "zh-cn", "ja", "ko"]
# ===== TTS FUNCTION =====
DEFAULT_REF = "model/samples/nu-luu-loat.wav"  # default reference-speaker sample
def tts_fn(text, language):
    """Synthesize ``text`` in ``language`` using the default reference voice.

    Conditions the XTTS model on the DEFAULT_REF speaker sample, runs
    inference, and writes the result to a unique temporary WAV file.

    Args:
        text: Text to synthesize.
        language: Language code (one of LANGS, e.g. "vi").

    Returns:
        Path to the generated 24 kHz WAV file. The file is not deleted
        automatically; the caller owns its lifetime.

    Raises:
        FileNotFoundError: if the reference audio sample is missing.
    """
    ref_audio = DEFAULT_REF
    # Fail fast with a clear message instead of an opaque error deep inside
    # get_conditioning_latents. (Also drops the old debug prints and the
    # redundant function-local `import os` that shadowed the module import.)
    if not os.path.exists(ref_audio):
        raise FileNotFoundError(f"Reference audio not found: {os.path.abspath(ref_audio)}")
    # Speaker conditioning latents from the reference clip.
    gpt_latent, spk_embed = MODEL.get_conditioning_latents(
        audio_path=ref_audio,
        gpt_cond_len=18,
        gpt_cond_chunk_len=4,
        max_ref_length=50,
    )
    out = MODEL.inference(
        text=text,
        language=language,
        gpt_cond_latent=gpt_latent,
        speaker_embedding=spk_embed,
        temperature=0.65,
        repetition_penalty=2.5,
        enable_text_splitting=True,
    )
    # Shape (1, samples) as torchaudio.save expects.
    wav = torch.tensor(out["wav"]).unsqueeze(0)
    # Unique output path per call: the old fixed "output.wav" was clobbered
    # when two requests ran concurrently (FastAPI runs sync endpoints in a
    # threadpool).
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # torchaudio.save reopens the path itself
    torchaudio.save(out_path, wav, 24000)
    return out_path
# ===== FASTAPI SERVER =====
@api_app.post("/api/speak")
def speak_api(text: str, language: str = "vi"):
    """HTTP endpoint: synthesize ``text`` and stream back the WAV file.

    On success returns the audio with media type ``audio/wav``. On failure
    returns ``{"error": ...}`` with HTTP 500 so clients can detect the error
    from the status code.
    """
    try:
        path = tts_fn(text, language)
    except Exception as e:
        # Bug fix: the original returned a plain dict (HTTP 200) on failure,
        # so gradio_client — which checks status_code == 200 — saved the JSON
        # error body to voice.wav as if it were audio.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return FileResponse(path, media_type="audio/wav")
# ===== GRADIO CLIENT =====
def gradio_client(text, language):
    """Call the local /api/speak endpoint and save the returned audio.

    Returns a ``(wav_path_or_None, status_message)`` pair for the Gradio
    outputs: the saved file path on success, ``None`` plus an error message
    otherwise.
    """
    try:
        resp = requests.post(
            "http://127.0.0.1:8000/api/speak",
            params={"text": text, "language": language},
        )
        # Anything other than 200 is reported to the UI as an API error.
        if resp.status_code != 200:
            return None, f"❌ Lỗi API: {resp.status_code}"
        with open("voice.wav", "wb") as fh:
            fh.write(resp.content)
        return "voice.wav", "✅ Hoàn tất!"
    except Exception as exc:
        return None, f"❌ Lỗi: {str(exc)}"
# ===== GRADIO UI =====
with gr.Blocks(title="ViXTTS - Gradio + API") as demo:
    gr.Markdown("## 🎙️ Vietnamese TTS - CPU (Spaces HuggingFace)")
    with gr.Row():
        with gr.Column(scale=1):
            # Input column: text to speak, language choice, trigger button.
            text_in = gr.Textbox(label="Văn bản", value="Xin chào!", lines=4)
            lang_dd = gr.Dropdown(label="Ngôn ngữ", choices=LANGS, value="vi")
            btn = gr.Button("🎧 Tạo giọng")
        with gr.Column(scale=1):
            # Output column: generated audio (auto-plays) and a status line.
            audio_out = gr.Audio(label="Kết quả", autoplay=True)
            info_out = gr.Textbox(label="Trạng thái", interactive=False)
    # The button round-trips through the local FastAPI endpoint.
    btn.click(gradio_client, inputs=[text_in, lang_dd], outputs=[audio_out, info_out])
# ===== RUN API AND GRADIO SIDE BY SIDE =====
if __name__ == "__main__":
    # Serve the FastAPI app on :8000 from a background daemon thread so the
    # Gradio UI (:7860) can own the main thread.
    def _serve_api():
        uvicorn.run(api_app, host="0.0.0.0", port=8000)

    api_thread = threading.Thread(target=_serve_api, daemon=True)
    api_thread.start()
    demo.launch(server_name="0.0.0.0", server_port=7860)