"""ViXTTS text-to-speech demo: a FastAPI endpoint plus a Gradio UI client.

The FastAPI server (port 8000) performs the actual synthesis; the Gradio UI
(port 7860) is a thin HTTP client that calls it. CPU-only inference.
"""

import os
import tempfile  # reserved for per-request output paths; see NOTE in tts_fn
import threading

import gradio as gr
import requests
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download, snapshot_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# ===== MODEL SETUP =====
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
os.makedirs(checkpoint_dir, exist_ok=True)

api_app = FastAPI()

# Download weights only if any required file is missing from the checkpoint dir.
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
if not all(f in os.listdir(checkpoint_dir) for f in required_files):
    snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir)
    # speakers_xtts.pth is not shipped in the viXTTS repo; fetch it from the
    # upstream Coqui XTTS-v2 repo.
    hf_hub_download("coqui/XTTS-v2", "speakers_xtts.pth", local_dir=checkpoint_dir)

config = XttsConfig()
config.load_json(os.path.join(checkpoint_dir, "config.json"))
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)

# CPU-only: move the model off any GPU and force float32 for the GPT module.
MODEL.cpu()
MODEL.gpt.float()
torch.set_num_threads(4)
torch.backends.mkldnn.enabled = True

LANGS = ["vi", "en", "zh-cn", "ja", "ko"]

# ===== TTS FUNCTION =====
DEFAULT_REF = "model/samples/nu-luu-loat.wav"  # default reference sample

# The conditioning latents depend only on DEFAULT_REF, which never changes at
# runtime — compute them once (lazily) instead of on every request.
_COND_CACHE: dict = {}


def _get_conditioning():
    """Return the cached (gpt_cond_latent, speaker_embedding) for DEFAULT_REF."""
    if "latents" not in _COND_CACHE:
        _COND_CACHE["latents"] = MODEL.get_conditioning_latents(
            audio_path=DEFAULT_REF,
            gpt_cond_len=18,
            gpt_cond_chunk_len=4,
            max_ref_length=50,
        )
    return _COND_CACHE["latents"]


def tts_fn(text, language):
    """Synthesize `text` in `language` and return the path to the output WAV.

    Voice conditioning always comes from DEFAULT_REF. Audio is written at
    24 kHz mono.
    """
    print(">>> Server-side ref_audio path:", os.path.abspath(DEFAULT_REF))
    print(">>> Exists:", os.path.exists(DEFAULT_REF))

    gpt_latent, spk_embed = _get_conditioning()

    out = MODEL.inference(
        text=text,
        language=language,
        gpt_cond_latent=gpt_latent,
        speaker_embedding=spk_embed,
        temperature=0.65,
        repetition_penalty=2.5,
        enable_text_splitting=True,
    )

    # out["wav"] is a 1-D waveform; torchaudio.save expects (channels, samples).
    wav = torch.tensor(out["wav"]).unsqueeze(0)
    # NOTE(review): a fixed path is a race under concurrent requests — two
    # callers can clobber each other's file. Consider tempfile.mkstemp(".wav")
    # per request (plus cleanup) if this ever serves parallel traffic.
    out_path = "output.wav"
    torchaudio.save(out_path, wav, 24000)
    print(">>> Generated wav path:", os.path.abspath(out_path))
    return out_path


# ===== FASTAPI SERVER =====
@api_app.post("/api/speak")
def speak_api(text: str, language: str = "vi"):
    """Synthesize `text` and stream the resulting WAV back to the caller.

    On failure, returns a JSON body {"error": ...} instead of a 500 traceback.
    """
    try:
        path = tts_fn(text, language)
        return FileResponse(path, media_type="audio/wav")
    except Exception as e:
        return {"error": str(e)}


# ===== GRADIO CLIENT =====
def gradio_client(text, language):
    """Call the local /api/speak endpoint and save the returned audio.

    Returns a (audio_path_or_None, status_message) pair for the Gradio outputs.
    """
    try:
        r = requests.post(
            "http://127.0.0.1:8000/api/speak",
            params={"text": text, "language": language},
        )
        if r.status_code == 200:
            with open("voice.wav", "wb") as f:
                f.write(r.content)
            return "voice.wav", "✅ Hoàn tất!"
        return None, f"❌ Lỗi API: {r.status_code}"
    except Exception as e:
        return None, f"❌ Lỗi: {str(e)}"


# ===== GRADIO UI =====
with gr.Blocks(title="ViXTTS - Gradio + API") as demo:
    gr.Markdown("## 🎙️ Vietnamese TTS - CPU (Spaces HuggingFace)")
    with gr.Row():
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Văn bản", value="Xin chào!", lines=4)
            lang_dd = gr.Dropdown(label="Ngôn ngữ", choices=LANGS, value="vi")
            btn = gr.Button("🎧 Tạo giọng")
        with gr.Column(scale=1):
            audio_out = gr.Audio(label="Kết quả", autoplay=True)
            info_out = gr.Textbox(label="Trạng thái", interactive=False)
    btn.click(gradio_client, inputs=[text_in, lang_dd], outputs=[audio_out, info_out])

# ===== RUN API + GRADIO IN PARALLEL =====
if __name__ == "__main__":
    # The API runs in a daemon thread; Gradio blocks the main thread.
    threading.Thread(
        target=lambda: uvicorn.run(api_app, host="0.0.0.0", port=8000),
        daemon=True,
    ).start()
    demo.launch(server_name="0.0.0.0", server_port=7860)