nxhong committed on
Commit
40ff39d
·
verified ·
1 Parent(s): 38d1a97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -157
app.py CHANGED
@@ -1,181 +1,75 @@
1
  import os
2
- import time
3
- import threading
4
  import torch
5
  import torchaudio
6
  import gradio as gr
7
- import spaces
8
- from fastapi import FastAPI
9
- from fastapi.responses import FileResponse
10
- import uvicorn
11
  from huggingface_hub import snapshot_download, hf_hub_download
12
  from TTS.tts.configs.xtts_config import XttsConfig
13
  from TTS.tts.models.xtts import Xtts
14
- from vinorm import TTSnorm
15
 
16
 
17
- # ========== SETUP MODEL ==========
18
- print("🔽 Đang tải mô hình capleaf/viXTTS...")
19
-
20
  checkpoint_dir = "model/"
21
  repo_id = "capleaf/viXTTS"
22
- use_deepspeed = False
23
-
24
  os.makedirs(checkpoint_dir, exist_ok=True)
25
- required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
26
- files_in_dir = os.listdir(checkpoint_dir)
27
- if not all(file in files_in_dir for file in required_files):
28
  snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir)
29
- hf_hub_download(
30
- repo_id="coqui/XTTS-v2",
31
- filename="speakers_xtts.pth",
32
- local_dir=checkpoint_dir,
33
- )
34
 
35
- xtts_config = os.path.join(checkpoint_dir, "config.json")
36
  config = XttsConfig()
37
- config.load_json(xtts_config)
38
  MODEL = Xtts.init_from_config(config)
39
- MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
40
- if torch.cuda.is_available():
41
- MODEL.cuda()
42
-
43
- supported_languages = config.languages
44
- if "vi" not in supported_languages:
45
- supported_languages.append("vi")
46
-
47
-
48
- # ========== UTILITIES ==========
49
- def normalize_vietnamese_text(text):
50
- text = (
51
- TTSnorm(text, unknown=False, lower=False, rule=True)
52
- .replace("..", ".")
53
- .replace("!.", "!")
54
- .replace("?.", "?")
55
- .replace(" .", ".")
56
- .replace(" ,", ",")
57
- .replace('"', "")
58
- .replace("'", "")
59
- .replace("AI", "Ây Ai")
60
- .replace("A.I", "Ây Ai")
61
- )
62
- return text
63
-
64
 
65
- def calculate_keep_len(text, lang):
66
- if lang in ["ja", "zh-cn"]:
67
- return -1
68
- word_count = len(text.split())
69
- num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")
70
- if word_count < 5:
71
- return 15000 * word_count + 2000 * num_punct
72
- elif word_count < 10:
73
- return 13000 * word_count + 2000 * num_punct
74
- return -1
75
 
76
 
77
  # ========== TTS FUNCTION ==========
78
- @spaces.GPU
79
- def predict(text, language, ref_audio, normalize_text=True):
80
- if not text or len(text.strip()) == 0:
81
- return None, "⚠️ Vui lòng nhập nội dung văn bản."
82
-
83
- if language not in supported_languages:
84
- return None, f"❌ Ngôn ngữ '{language}' không được hỗ trợ."
85
-
86
- try:
87
- print(f"🎧 Đang sinh giọng nói [{language}] cho văn bản: {text[:50]}...")
88
-
89
- (gpt_cond_latent, speaker_embedding) = MODEL.get_conditioning_latents(
90
- audio_path=ref_audio,
91
- gpt_cond_len=30,
92
- gpt_cond_chunk_len=4,
93
- max_ref_length=60,
94
- )
95
-
96
- if normalize_text and language == "vi":
97
- text = normalize_vietnamese_text(text)
98
-
99
- t0 = time.time()
100
- out = MODEL.inference(
101
- text,
102
- language,
103
- gpt_cond_latent,
104
- speaker_embedding,
105
- repetition_penalty=5.0,
106
- temperature=0.75,
107
- enable_text_splitting=True,
108
- )
109
-
110
- inference_time = time.time() - t0
111
- rtf = (time.time() - t0) / out["wav"].shape[-1] * 24000
112
-
113
- keep_len = calculate_keep_len(text, language)
114
- out["wav"] = out["wav"][:keep_len]
115
- torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
116
-
117
- info = f"⏱️ Thời gian sinh âm: {round(inference_time, 2)}s\n⚙️ RTF: {rtf:.2f}"
118
- return "output.wav", info
119
-
120
- except Exception as e:
121
- print("❌ Error:", str(e))
122
- return None, f"Lỗi khi sinh giọng nói: {str(e)}"
123
-
124
-
125
- # ========== FASTAPI ==========
126
- api_app = FastAPI()
127
-
128
-
129
- @api_app.post("/api/speak")
130
- def speak_api(text: str = "Xin chào!", language: str = "vi"):
131
- ref_audio = "model/samples/nu-luu-loat.wav"
132
- audio_path, _ = predict(text, language, ref_audio, True)
133
- return FileResponse(audio_path, media_type="audio/wav")
134
-
135
-
136
- # ========== GRADIO UI ==========
137
- with gr.Blocks(title="🇻🇳 Vietnamese TTS - capleaf/viXTTS") as demo:
138
- gr.Markdown("## 🎙️ Text to Speech (ViXTTS)")
139
- gr.Markdown("Nhập văn bản, chọn ngôn ngữ và giọng mẫu để tạo giọng nói.")
140
-
141
- with gr.Row():
142
- with gr.Column(scale=1):
143
- input_text = gr.Textbox(
144
- label="Văn bản cần đọc",
145
- value="Xin chào! Tôi là mô hình tạo giọng nói tiếng Việt.",
146
- lines=4,
147
- )
148
- lang_dd = gr.Dropdown(
149
- label="Ngôn ngữ",
150
- choices=["vi", "en", "zh-cn", "ja", "ko"],
151
- value="vi",
152
- )
153
- ref_audio = gr.Audio(
154
- label="Giọng mẫu (reference)",
155
- type="filepath",
156
- value="model/samples/nu-luu-loat.wav",
157
- )
158
- norm_cb = gr.Checkbox(label="Chuẩn hóa văn bản", value=True)
159
-
160
- # ✅ Đây là nút Predict
161
- tts_button = gr.Button("🎙️ Tạo giọng nói", variant="primary")
162
-
163
- with gr.Column(scale=1):
164
- output_audio = gr.Audio(label="Kết quả âm thanh", autoplay=True)
165
- output_info = gr.Textbox(label="Thông tin chi tiết", interactive=False)
166
-
167
- tts_button.click(
168
- predict,
169
- inputs=[input_text, lang_dd, ref_audio, norm_cb],
170
- outputs=[output_audio, output_info],
171
  )
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
- # ========== CHẠY SONG SONG FASTAPI + GRADIO ==========
175
- if __name__ == "__main__":
176
- def run_api():
177
- uvicorn.run(api_app, host="0.0.0.0", port=8000)
178
 
179
- threading.Thread(target=run_api, daemon=True).start()
180
- demo.queue()
181
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
  import os
 
 
2
  import torch
3
  import torchaudio
4
  import gradio as gr
 
 
 
 
5
  from huggingface_hub import snapshot_download, hf_hub_download
6
  from TTS.tts.configs.xtts_config import XttsConfig
7
  from TTS.tts.models.xtts import Xtts
 
8
 
9
 
10
# ========== LOAD MODEL ==========
# Download the viXTTS checkpoint (plus the speaker file from the base
# XTTS-v2 repo) on first run, then build the model for CPU-only inference.
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
os.makedirs(checkpoint_dir, exist_ok=True)

required = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
if not all(x in os.listdir(checkpoint_dir) for x in required):
    snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir)
    # speakers_xtts.pth is not shipped in capleaf/viXTTS; fetch it from the
    # upstream coqui/XTTS-v2 repository.
    hf_hub_download("coqui/XTTS-v2", "speakers_xtts.pth", local_dir=checkpoint_dir)

config = XttsConfig()
config.load_json(os.path.join(checkpoint_dir, "config.json"))
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)

# Force CPU + optimize for CPU inference.
# BUG FIX: the coqui `Xtts` class does not expose `model_gpt` / `vocoder`
# attributes (its submodules are named `gpt` / `hifigan_decoder`), so the
# previous per-submodule `.float()` calls raised AttributeError at startup.
# Casting the whole nn.Module covers every submodule and is safe regardless
# of the library's internal attribute names.
MODEL.cpu()
MODEL.float()
torch.set_num_threads(4)  # cap intra-op threads for shared CPU hosts
torch.backends.mkldnn.enabled = True  # enable oneDNN kernels for CPU ops
 
 
 
 
31
 
32
 
33
# ========== TTS FUNCTION ==========
def predict(text, ref_audio):
    """Synthesize Vietnamese speech for *text* in the voice of *ref_audio*.

    Parameters:
        text: the sentence(s) to read aloud.
        ref_audio: filepath of a reference clip defining the target voice.

    Returns:
        (wav_path, status_message) — wav_path is None on invalid input,
        otherwise the path of the saved 24 kHz mono WAV file.
    """
    # Robustness fix: reject whitespace-only input too, not just "" / None —
    # previously "   " slipped through to the model.
    if not text or not text.strip():
        return None, "⚠️ Nhập nội dung đi."

    # Extract the reference speaker's voice characteristics.
    gpt_latent, spk_embed = MODEL.get_conditioning_latents(
        audio_path=ref_audio,
        gpt_cond_len=18,  # shorter conditioning window -> faster on CPU
        gpt_cond_chunk_len=4,
        max_ref_length=50,
    )

    out = MODEL.inference(
        text,
        "vi",
        gpt_latent,
        spk_embed,
        enable_text_splitting=False,  # single pass is faster for short inputs
        temperature=0.7,
        repetition_penalty=3.0,
    )

    # Save as (channels, samples) at the model's 24 kHz sample rate.
    wav = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save("output.wav", wav, 24000)
    return "output.wav", "✅ Xong rồi!"
59
+
60
+
61
# ========== GRADIO UI (also exposed as the API) ==========
with gr.Blocks() as demo:
    gr.Markdown("### 🇻🇳 ViXTTS - CPU Optimized (HuggingFace)")

    # Inputs: the text to read and a reference clip that defines the voice.
    text_in = gr.Textbox(label="Văn bản", value="Xin chào, đây là giọng nói tiếng Việt.")
    ref_in = gr.Audio(label="Giọng mẫu", type="filepath", value="model/samples/nu-luu-loat.wav")
    speak_btn = gr.Button("🎙️ Tạo giọng")

    # Outputs: the synthesized audio (auto-played) and a status line.
    audio_out = gr.Audio(label="Kết quả", autoplay=True)
    info_out = gr.Textbox(label="Trạng thái", interactive=False)

    # Wire the button to predict(text, ref_audio) -> (wav_path, status).
    speak_btn.click(predict, inputs=[text_in, ref_in], outputs=[audio_out, info_out])

# Launch with Gradio defaults (server/queueing options left to the platform).
demo.launch()