nxhong committed on
Commit
8538350
·
verified ·
1 Parent(s): f9affcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -30
app.py CHANGED
@@ -2,74 +2,114 @@ import os
2
  import torch
3
  import torchaudio
4
  import gradio as gr
 
 
 
 
5
  from huggingface_hub import snapshot_download, hf_hub_download
6
  from TTS.tts.configs.xtts_config import XttsConfig
7
  from TTS.tts.models.xtts import Xtts
8
 
9
-
10
  # ========== LOAD MODEL ==========
11
  checkpoint_dir = "model/"
12
  repo_id = "capleaf/viXTTS"
13
  os.makedirs(checkpoint_dir, exist_ok=True)
14
 
15
- required = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
16
- if not all(x in os.listdir(checkpoint_dir) for x in required):
17
  snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir)
18
  hf_hub_download("coqui/XTTS-v2", "speakers_xtts.pth", local_dir=checkpoint_dir)
19
 
20
  config = XttsConfig()
21
- config.load_json(f"{checkpoint_dir}/config.json")
22
  MODEL = Xtts.init_from_config(config)
23
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)
24
 
25
- # ✅ Force CPU
26
  MODEL.cpu()
27
  MODEL.gpt.float()
28
  MODEL.hifi_gan.float()
29
-
30
  torch.set_num_threads(4)
31
  torch.backends.mkldnn.enabled = True
32
 
 
 
33
 
34
  # ========== TTS FUNCTION ==========
35
- def predict(text, ref_audio):
36
- if not text:
37
- return None, "⚠️ Nhập nội dung đi."
 
 
 
38
 
39
  gpt_latent, spk_embed = MODEL.get_conditioning_latents(
40
  audio_path=ref_audio,
41
  gpt_cond_len=18,
42
  gpt_cond_chunk_len=4,
43
- max_ref_length=50,
44
  )
45
 
46
  out = MODEL.inference(
47
- text,
48
- "vi",
49
- gpt_latent,
50
- spk_embed,
51
- enable_text_splitting=False,
52
- temperature=0.7,
53
- repetition_penalty=3.0,
54
  )
55
 
56
  wav = torch.tensor(out["wav"]).unsqueeze(0)
57
  torchaudio.save("output.wav", wav, 24000)
58
- return "output.wav", "✅ Xong rồi!"
59
-
60
 
61
- # ========== GRADIO UI ==========
62
- with gr.Blocks() as demo:
63
- gr.Markdown("### 🇻🇳 ViXTTS - CPU Optimized (HuggingFace)")
64
-
65
- text_in = gr.Textbox(label="Văn bản", value="Xin chào! Đây là giọng nói tiếng Việt.")
66
- ref_in = gr.Audio(label="Giọng mẫu", type="filepath", value="model/samples/nu-luu-loat.wav")
67
- speak_btn = gr.Button("🎙️ Tạo giọng")
68
 
69
- audio_out = gr.Audio(label="Kết quả", autoplay=True)
70
- info_out = gr.Textbox(label="Trạng thái", interactive=False)
 
 
 
71
 
72
- speak_btn.click(predict, inputs=[text_in, ref_in], outputs=[audio_out, info_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
 
 
 
 
74
 
75
- demo.launch()
 
 
2
  import torch
3
  import torchaudio
4
  import gradio as gr
5
+ from fastapi import FastAPI
6
+ from fastapi.responses import FileResponse
7
+ import uvicorn
8
+ import threading
9
  from huggingface_hub import snapshot_download, hf_hub_download
10
  from TTS.tts.configs.xtts_config import XttsConfig
11
  from TTS.tts.models.xtts import Xtts
12
 
 
13
# ========== LOAD MODEL ==========
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
os.makedirs(checkpoint_dir, exist_ok=True)

# Download the checkpoint only when a required file is missing locally.
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
if not set(required_files).issubset(os.listdir(checkpoint_dir)):
    snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir)
    hf_hub_download("coqui/XTTS-v2", "speakers_xtts.pth", local_dir=checkpoint_dir)

# Build the model from its JSON config, then load the weights from the same
# directory.
config = XttsConfig()
config.load_json(os.path.join(checkpoint_dir, "config.json"))
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)

# CPU only: move the model to the CPU and cast the GPT and HiFi-GAN parts
# to float32.
MODEL.cpu()
MODEL.gpt.float()
MODEL.hifi_gan.float()
torch.set_num_threads(4)
torch.backends.mkldnn.enabled = True

# Supported languages
LANGS = ["vi", "en", "zh-cn", "ja", "ko"]
37
 
38
  # ========== TTS FUNCTION ==========
39
def predict(text, language, ref_audio):
    """Synthesize speech for ``text`` and save it as a WAV file.

    Args:
        text: Text to speak. ``None``, empty and whitespace-only values are
            rejected with a status message instead of raising.
        language: Language code; must be one of the entries in ``LANGS``.
        ref_audio: Path of the reference recording, forwarded to
            ``MODEL.get_conditioning_latents`` as ``audio_path``.

    Returns:
        A ``(wav_path, status_message)`` tuple. ``wav_path`` is ``None``
        whenever the request is rejected.
    """
    import tempfile  # local import: only this function needs it

    # Guard against None as well as blank text; ``None.strip()`` would raise.
    if not text or not text.strip():
        return None, "⚠️ Nhập nội dung."

    if language not in LANGS:
        return None, f"❌ Ngôn ngữ '{language}' không được hỗ trợ."

    # A cleared audio widget yields no path; reject it before calling the model.
    if not ref_audio:
        return None, "⚠️ Chưa có giọng mẫu."

    gpt_latent, spk_embed = MODEL.get_conditioning_latents(
        audio_path=ref_audio,
        gpt_cond_len=18,
        gpt_cond_chunk_len=4,
        max_ref_length=50,
    )

    out = MODEL.inference(
        text=text,
        language=language,
        gpt_cond_latent=gpt_latent,
        speaker_embedding=spk_embed,
        temperature=0.65,
        repetition_penalty=2.5,
        enable_text_splitting=False,
    )

    # Add a leading dimension before handing the samples to torchaudio.save.
    wav = torch.tensor(out["wav"]).unsqueeze(0)

    # Write each result to its own file: the API and the Gradio UI run side by
    # side, so a shared "output.wav" could be overwritten while another
    # request is still serving it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    torchaudio.save(out_path, wav, 24000)
    return out_path, "✅ Hoàn tất!"
 
66
 
67
# ========== FASTAPI ==========
api_app = FastAPI()


@api_app.post("/api/speak")
def speak_api(text: str = "Xin chào!", language: str = "vi"):
    """Synthesize ``text`` and return the resulting WAV file.

    The reference recording is fixed to ``model/samples/nu-luu-loat.wav``.

    Args:
        text: Text to speak.
        language: Language code handed to ``predict``.

    Returns:
        A ``FileResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: Status 400 carrying the status message when
            ``predict`` rejects the request and returns no audio path.
    """
    from fastapi import HTTPException  # local import keeps this block self-contained

    ref_audio = "model/samples/nu-luu-loat.wav"
    audio_path, status = predict(text, language, ref_audio)
    if audio_path is None:
        # There is no file to serve; report the rejection to the client
        # instead of passing None to FileResponse.
        raise HTTPException(status_code=400, detail=status)
    return FileResponse(audio_path, media_type="audio/wav")
75
 
76
# ========== GRADIO UI ==========
with gr.Blocks(title="🇻🇳 Vietnamese TTS - CPU") as demo:
    gr.Markdown("## 🎙️ Text to Speech (ViXTTS)")

    with gr.Row():
        # First column: text, language, reference voice and the trigger button.
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Văn bản",
                value="Xin chào! Tôi là mô hình tạo giọng nói tiếng Việt.",
                lines=4,
            )
            lang_dd = gr.Dropdown(label="Ngôn ngữ", choices=LANGS, value="vi")
            ref_audio = gr.Audio(
                label="Giọng mẫu (reference)",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button("🎙️ Tạo giọng", variant="primary")

        # Second column: generated audio and the status message.
        with gr.Column(scale=1):
            output_audio = gr.Audio(label="Kết quả", autoplay=True)
            output_info = gr.Textbox(label="Trạng thái", interactive=False)

    # predict receives (text, language, reference path) and fills both outputs.
    tts_button.click(
        fn=predict,
        inputs=[input_text, lang_dd, ref_audio],
        outputs=[output_audio, output_info],
    )
108
 
109
# ========== RUN API + GRADIO SIDE BY SIDE ==========
if __name__ == "__main__":
    # Serve the FastAPI app on port 8000 from a daemon thread; the main
    # thread then launches the Gradio UI on port 7860.
    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(api_app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    )
    api_thread.start()
    demo.launch(server_name="0.0.0.0", server_port=7860)