nxhong commited on
Commit
f1134ba
·
verified ·
1 Parent(s): 968620b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -57
app.py CHANGED
@@ -1,16 +1,17 @@
1
  import os
2
  import torch
3
  import torchaudio
 
4
  import gradio as gr
 
5
  from fastapi import FastAPI
6
  from fastapi.responses import FileResponse
7
- import uvicorn
8
- import threading
9
  from huggingface_hub import snapshot_download, hf_hub_download
10
  from TTS.tts.configs.xtts_config import XttsConfig
11
  from TTS.tts.models.xtts import Xtts
 
12
 
13
- # ========== LOAD MODEL ==========
14
  checkpoint_dir = "model/"
15
  repo_id = "capleaf/viXTTS"
16
  os.makedirs(checkpoint_dir, exist_ok=True)
@@ -25,30 +26,19 @@ config.load_json(os.path.join(checkpoint_dir, "config.json"))
25
  MODEL = Xtts.init_from_config(config)
26
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)
27
 
28
- # CPU only
29
  MODEL.cpu()
30
  MODEL.gpt.float()
31
  torch.set_num_threads(4)
32
  torch.backends.mkldnn.enabled = True
33
 
34
- # Ngôn ngữ hỗ trợ
35
  LANGS = ["vi", "en", "zh-cn", "ja", "ko"]
36
 
37
- # ========== TTS FUNCTION ==========
38
- def predict(text, language, ref_audio):
39
- if not text.strip():
40
- return None, "⚠️ Nhập nội dung."
41
-
42
- if language not in LANGS:
43
- return None, f"❌ Ngôn ngữ '{language}' không được hỗ trợ."
44
-
45
  gpt_latent, spk_embed = MODEL.get_conditioning_latents(
46
- audio_path=ref_audio,
47
- gpt_cond_len=18,
48
- gpt_cond_chunk_len=4,
49
- max_ref_length=50
50
  )
51
-
52
  out = MODEL.inference(
53
  text=text,
54
  language=language,
@@ -58,57 +48,55 @@ def predict(text, language, ref_audio):
58
  repetition_penalty=2.5,
59
  enable_text_splitting=False
60
  )
61
-
62
  wav = torch.tensor(out["wav"]).unsqueeze(0)
63
  torchaudio.save("output.wav", wav, 24000)
64
- return "output.wav", "✅ Hoàn tất!"
65
 
66
- # ========== FASTAPI ==========
67
  api_app = FastAPI()
68
 
69
  @api_app.post("/api/speak")
70
- def speak_api(text: str = "Xin chào!", language: str = "vi"):
71
- ref_audio = "model/samples/nu-luu-loat.wav"
72
- audio_path, _ = predict(text, language, ref_audio)
73
- return FileResponse(audio_path, media_type="audio/wav")
74
-
75
- # ========== GRADIO UI ==========
76
- with gr.Blocks(title="🇻🇳 Vietnamese TTS - CPU") as demo:
77
- gr.Markdown("## 🎙️ Text to Speech (ViXTTS)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  with gr.Row():
80
  with gr.Column(scale=1):
81
- input_text = gr.Textbox(
82
- label="Văn bản",
83
- value="Xin chào! Tôi là mô hình tạo giọng nói tiếng Việt.",
84
- lines=4
85
- )
86
- lang_dd = gr.Dropdown(
87
- label="Ngôn ngữ",
88
- choices=LANGS,
89
- value="vi"
90
- )
91
- ref_audio = gr.Audio(
92
- label="Giọng mẫu (reference)",
93
- type="filepath",
94
- value="model/samples/nu-luu-loat.wav"
95
- )
96
- tts_button = gr.Button("🎙️ Tạo giọng", variant="primary")
97
 
98
  with gr.Column(scale=1):
99
- output_audio = gr.Audio(label="Kết quả", autoplay=True)
100
- output_info = gr.Textbox(label="Trạng thái", interactive=False)
101
 
102
- tts_button.click(
103
- predict,
104
- inputs=[input_text, lang_dd, ref_audio],
105
- outputs=[output_audio, output_info],
106
- )
107
 
108
- # ========== CHẠY SONG SONG API + GRADIO ==========
109
  if __name__ == "__main__":
110
- def run_api():
111
- uvicorn.run(api_app, host="0.0.0.0", port=8000)
112
-
113
- threading.Thread(target=run_api, daemon=True).start()
114
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
2
  import torch
3
  import torchaudio
4
+ import threading
5
  import gradio as gr
6
+ import requests
7
  from fastapi import FastAPI
8
  from fastapi.responses import FileResponse
 
 
9
  from huggingface_hub import snapshot_download, hf_hub_download
10
  from TTS.tts.configs.xtts_config import XttsConfig
11
  from TTS.tts.models.xtts import Xtts
12
+ import uvicorn
13
 
14
+ # ===== MODEL SETUP =====
15
  checkpoint_dir = "model/"
16
  repo_id = "capleaf/viXTTS"
17
  os.makedirs(checkpoint_dir, exist_ok=True)
 
26
  MODEL = Xtts.init_from_config(config)
27
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)
28
 
29
+ # CPU-only
30
  MODEL.cpu()
31
  MODEL.gpt.float()
32
  torch.set_num_threads(4)
33
  torch.backends.mkldnn.enabled = True
34
 
 
35
  LANGS = ["vi", "en", "zh-cn", "ja", "ko"]
36
 
37
+ # ===== TTS FUNCTION =====
38
+ def tts_fn(text, language, ref_audio):
 
 
 
 
 
 
39
  gpt_latent, spk_embed = MODEL.get_conditioning_latents(
40
+ audio_path=ref_audio, gpt_cond_len=18, gpt_cond_chunk_len=4, max_ref_length=50
 
 
 
41
  )
 
42
  out = MODEL.inference(
43
  text=text,
44
  language=language,
 
48
  repetition_penalty=2.5,
49
  enable_text_splitting=False
50
  )
 
51
  wav = torch.tensor(out["wav"]).unsqueeze(0)
52
  torchaudio.save("output.wav", wav, 24000)
53
+ return "output.wav"
54
 
55
+ # ===== FASTAPI SERVER =====
56
  api_app = FastAPI()
57
 
58
  @api_app.post("/api/speak")
59
+ def speak_api(text: str, language: str = "vi", ref_audio: str = "model/samples/nu-luu-loat.wav"):
60
+ try:
61
+ path = tts_fn(text, language, ref_audio)
62
+ return FileResponse(path, media_type="audio/wav")
63
+ except Exception as e:
64
+ return {"error": str(e)}
65
+
66
+ # ===== GRADIO CLIENT (gọi API nội bộ) =====
67
+ def gradio_client(text, language, ref_audio):
68
+ try:
69
+ r = requests.post(
70
+ "http://127.0.0.1:8000/api/speak",
71
+ params={"text": text, "language": language, "ref_audio": ref_audio}
72
+ )
73
+ if r.status_code == 200:
74
+ with open("voice.wav", "wb") as f:
75
+ f.write(r.content)
76
+ return "voice.wav", "✅ Hoàn tất!"
77
+ else:
78
+ return None, f"❌ Lỗi API: {r.status_code}"
79
+ except Exception as e:
80
+ return None, f"❌ Lỗi: {str(e)}"
81
+
82
+ # ===== GRADIO UI =====
83
+ with gr.Blocks(title="ViXTTS - Gradio + API") as demo:
84
+ gr.Markdown("## 🎙️ Vietnamese TTS - CPU (Spaces HuggingFace)")
85
 
86
  with gr.Row():
87
  with gr.Column(scale=1):
88
+ text_in = gr.Textbox(label="Văn bản", value="Xin chào!", lines=4)
89
+ lang_dd = gr.Dropdown(label="Ngôn ngữ", choices=LANGS, value="vi")
90
+ ref_audio = gr.Audio(label="Giọng mẫu", type="filepath", value="model/samples/nu-luu-loat.wav")
91
+ btn = gr.Button("🎧 Tạo giọng")
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  with gr.Column(scale=1):
94
+ audio_out = gr.Audio(label="Kết quả", autoplay=True)
95
+ info_out = gr.Textbox(label="Trạng thái", interactive=False)
96
 
97
+ btn.click(gradio_client, inputs=[text_in, lang_dd, ref_audio], outputs=[audio_out, info_out])
 
 
 
 
98
 
99
+ # ===== CHẠY SONG SONG API + GRADIO =====
100
  if __name__ == "__main__":
101
+ threading.Thread(target=lambda: uvicorn.run(api_app, host="0.0.0.0", port=8000), daemon=True).start()
 
 
 
102
  demo.launch(server_name="0.0.0.0", server_port=7860)