pnnbao-ump committed on
Commit 3cdf1f1 · verified · 1 Parent(s): 5d4607a

Upload 46 files
app.py CHANGED
@@ -1,318 +1,514 @@
- import spaces  # MUST be imported BEFORE anything else on HF Spaces ZeroGPU
- import os
- os.environ['SPACES_ZERO_GPU'] = '1'  # Set environment variable explicitly
-
- import gradio as gr
- import soundfile as sf
- import tempfile
- import torch
- from vieneu_tts import VieNeuTTS
- import time
-
- print("⏳ Đang khởi động VieNeu-TTS...")
-
- # --- 1. SETUP MODEL ---
- print("📦 Đang tải model...")
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"🖥️ Sử dụng thiết bị: {device.upper()}")
-
- try:
-     tts = VieNeuTTS(
-         backbone_repo="pnnbao-ump/VieNeu-TTS",
-         backbone_device=device,
-         codec_repo="neuphonic/neucodec",
-         codec_device=device
-     )
-     print("✅ Model đã tải xong!")
- except Exception as e:
-     print(f"⚠️ Không thể tải model (Chế độ UI Demo): {e}")
-     class MockTTS:
-         def encode_reference(self, path): return None
-         def infer(self, text, ref, ref_text):
-             import numpy as np
-             # Simulate latency so the timing feature can be tested
-             time.sleep(1.5)
-             return np.random.uniform(-0.5, 0.5, 24000*3)
-     tts = MockTTS()
-
- # --- 2. DATA ---
- VOICE_SAMPLES = {
-     "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
-     "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
-     "Bình (nam miền Bắc)": {"audio": "./sample/Bình (nam miền Bắc).wav", "text": "./sample/Bình (nam miền Bắc).txt"},
-     "Nguyên (nam miền Nam)": {"audio": "./sample/Nguyên (nam miền Nam).wav", "text": "./sample/Nguyên (nam miền Nam).txt"},
-     "Sơn (nam miền Nam)": {"audio": "./sample/Sơn (nam miền Nam).wav", "text": "./sample/Sơn (nam miền Nam).txt"},
-     "Đoan (nữ miền Nam)": {"audio": "./sample/Đoan (nữ miền Nam).wav", "text": "./sample/Đoan (nữ miền Nam).txt"},
-     "Ngọc (nữ miền Bắc)": {"audio": "./sample/Ngọc (nữ miền Bắc).wav", "text": "./sample/Ngọc (nữ miền Bắc).txt"},
-     "Ly (nữ miền Bắc)": {"audio": "./sample/Ly (nữ miền Bắc).wav", "text": "./sample/Ly (nữ miền Bắc).txt"},
-     "Dung (nữ miền Nam)": {"audio": "./sample/Dung (nữ miền Nam).wav", "text": "./sample/Dung (nữ miền Nam).txt"}
- }
-
- # --- 3. HELPER FUNCTIONS ---
- def load_reference_info(voice_choice):
-     if voice_choice in VOICE_SAMPLES:
-         audio_path = VOICE_SAMPLES[voice_choice]["audio"]
-         text_path = VOICE_SAMPLES[voice_choice]["text"]
-         try:
-             if os.path.exists(text_path):
-                 with open(text_path, "r", encoding="utf-8") as f:
-                     ref_text = f.read()
-                 return audio_path, ref_text
-             else:
-                 return audio_path, "⚠️ Không tìm thấy file text mẫu."
-         except Exception as e:
-             return None, f"❌ Lỗi: {str(e)}"
-     return None, ""
-
- @spaces.GPU(duration=120)
- def synthesize_speech(text, voice_choice, custom_audio, custom_text, mode_tab):
-     try:
-         if not text or text.strip() == "":
-             return None, "⚠️ Vui lòng nhập văn bản cần tổng hợp!"
-
-         # --- 250-CHARACTER LIMIT CHECK ---
-         if len(text) > 250:
-             return None, f"⚠️ Văn bản quá dài ({len(text)}/250 ký tự)! Vui lòng cắt ngắn lại để đảm bảo chất lượng."
-
-         # Reference selection logic
-         if mode_tab == "custom_mode":
-             if custom_audio is None or not custom_text:
-                 return None, "⚠️ Vui lòng tải lên Audio và nhập nội dung Audio đó."
-             ref_audio_path = custom_audio
-             ref_text_raw = custom_text
-             print("🎨 Mode: Custom Voice")
-         else:  # Preset
-             if voice_choice not in VOICE_SAMPLES:
-                 return None, "⚠️ Vui lòng chọn một giọng mẫu."
-             ref_audio_path = VOICE_SAMPLES[voice_choice]["audio"]
-             ref_text_path = VOICE_SAMPLES[voice_choice]["text"]
-
-             if not os.path.exists(ref_audio_path):
-                 return None, f"❌ Không tìm thấy file audio: {ref_audio_path}"
-
-             with open(ref_text_path, "r", encoding="utf-8") as f:
-                 ref_text_raw = f.read()
-             print(f"🎤 Mode: Preset Voice ({voice_choice})")
-
-         # Inference & timing
-         print(f"📝 Text: {text[:50]}...")
-
-         start_time = time.time()  # <--- start the timer
-
-         ref_codes = tts.encode_reference(ref_audio_path)
-         wav = tts.infer(text, ref_codes, ref_text_raw)
-
-         end_time = time.time()  # <--- stop the timer
-         process_time = end_time - start_time  # <--- compute processing time
-
-         # Save
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
-             sf.write(tmp_file.name, wav, 24000)
-             output_path = tmp_file.name
-
-         # <--- updated result message
-         return output_path, f"✅ Thành công! (Mất {process_time:.2f} giây để tạo)"
-
-     except Exception as e:
-         import traceback
-         traceback.print_exc()
-         return None, f"❌ Lỗi hệ thống: {str(e)}"
-
- # --- 4. UI SETUP ---
- # BUG FIX HERE: changed gr.themes.Ocean -> gr.themes.Soft
- theme = gr.themes.Soft(
-     primary_hue="indigo",
-     secondary_hue="cyan",
-     neutral_hue="slate",
-     font=[gr.themes.GoogleFont('Inter'), 'ui-sans-serif', 'system-ui'],
- ).set(
-     button_primary_background_fill="linear-gradient(90deg, #6366f1 0%, #0ea5e9 100%)",
-     button_primary_background_fill_hover="linear-gradient(90deg, #4f46e5 0%, #0284c7 100%)",
-     block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)",
- )
-
- # <--- FIXED CSS (dark navy background + light text)
- css = """
- .container { max-width: 1200px; margin: auto; }
- .header-box {
-     text-align: center;
-     margin-bottom: 25px;
-     padding: 25px;
-     background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%); /* Dark navy (Slate 900 -> 800) */
-     border-radius: 12px;
-     border: 1px solid #334155;
-     box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.3);
- }
- .header-title {
-     font-size: 2.5rem;
-     font-weight: 800;
-     color: white; /* White text */
-     background: -webkit-linear-gradient(45deg, #60A5FA, #22D3EE); /* Bright blue gradient so the title stands out */
-     -webkit-background-clip: text;
-     -webkit-text-fill-color: transparent;
-     margin-bottom: 10px;
- }
- .header-desc {
-     font-size: 1.1rem;
-     color: #cbd5e1; /* Light gray (Slate-300) */
-     margin-bottom: 15px;
- }
- .link-group a {
-     text-decoration: none;
-     margin: 0 10px;
-     font-weight: 600;
-     color: #94a3b8; /* Slightly brighter link color */
-     transition: color 0.2s;
- }
- .link-group a:hover { color: #38bdf8; text-shadow: 0 0 5px rgba(56, 189, 248, 0.5); }
-
- .status-box { font-weight: bold; text-align: center; border: none; background: transparent; }
- .warning-note {
-     background-color: #fff7ed;
-     border-left: 4px solid #f97316;
-     padding: 12px;
-     color: #9a3412;
-     font-size: 0.9rem;
-     border-radius: 4px;
-     margin-top: 10px;
-     margin-bottom: 10px;
- }
- """
-
- EXAMPLES_LIST = [
-     # Southern male
-     ["Về miền Tây không chỉ để ngắm nhìn sông nước hữu tình, mà còn để cảm nhận tấm chân tình của người dân nơi đây. Cùng ngồi xuồng ba lá len lỏi qua rặng dừa nước, nghe câu vọng cổ ngọt ngào thì còn gì bằng.", "Vĩnh (nam miền Nam)"],
-
-     # Northern male
-     ["Hà Nội những ngày vào thu mang một vẻ đẹp trầm mặc và cổ kính đến lạ thường. Đi dạo quanh Hồ Gươm vào sáng sớm, hít hà mùi hoa sữa nồng nàn và thưởng thức chút cốm làng Vòng là trải nghiệm khó quên.", "Bình (nam miền Bắc)"],
-
-     # Northern male
-     ["Sự bùng nổ của trí tuệ nhân tạo đang định hình lại cách chúng ta làm việc và sinh sống. Từ xe tự lái đến trợ lý ảo thông minh, công nghệ đang dần xóa nhòa ranh giới giữa thực tại và những bộ phim viễn tưởng.", "Tuyên (nam miền Bắc)"],
-
-     # Southern male
-     ["Sài Gòn hối hả là thế, nhưng chỉ cần tấp vào một quán cà phê ven đường, gọi ly bạc xỉu đá và ngắm nhìn dòng người qua lại, bạn sẽ thấy thành phố này cũng có những khoảng lặng thật bình yên và đáng yêu.", "Nguyên (nam miền Nam)"],
-
-     # Southern male
-     ["Để đảm bảo tiến độ dự án quan trọng này, chúng ta cần tập trung tối đa nguồn lực và phối hợp chặt chẽ giữa các phòng ban. Mọi khó khăn phát sinh cần được báo cáo ngay lập tức để ban lãnh đạo xử lý kịp thời.", "Sơn (nam miền Nam)"],
-
-     # Southern female
-     ["Ngày xửa ngày xưa, một ngôi làng nọ có cô Tấm xinh đẹp, nết na nhưng sớm mồ côi mẹ. Dù bị mẹ kế và Cám hãm hại đủ đường, Tấm vẫn giữ được tấm lòng lương thiện và cuối cùng tìm được hạnh phúc xứng đáng.", "Đoan (nữ miền Nam)"],
-
-     # Northern female
-     ["Dạ em chào anh chị, hiện tại bên em đang có chương trình ưu đãi đặc biệt cho căn hộ hướng sông này. Với thiết kế hiện đại và không gian xanh mát, đây chắc chắn là tổ ấm lý tưởng mà gia đình mình đang tìm kiếm.", "Ly (nữ miền Bắc)"],
-
-     # Northern female
-     ["Dưới cơn mưa phùn lất phất của những ngày cuối đông, em khẽ nép vào vai anh, cảm nhận hơi ấm lan tỏa. Những khoảnh khắc bình dị như thế này khiến em nhận ra rằng, hạnh phúc đôi khi chỉ đơn giản là được ở bên nhau.", "Ngọc (nữ miền Bắc)"],
- ]
-
- with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS Studio") as demo:
-
-     with gr.Column(elem_classes="container"):
-         # Header - updated HTML classes
-         gr.HTML("""
-         <div class="header-box">
-             <div class="header-title">🎙️ VieNeu-TTS Studio</div>
-             <div class="header-desc">
-                 Phiên bản: VieNeu-TTS-1000h (model mới nhất, train trên 1000 giờ dữ liệu)
-             </div>
-             <div class="link-group">
-                 <a href="https://huggingface.co/pnnbao-ump/VieNeu-TTS" target="_blank">🤗 Model Card</a>
-                 <a href="https://huggingface.co/datasets/pnnbao-ump/VieNeu-TTS-1000h" target="_blank">📖 Dataset 1000h</a> •
-                 <a href="https://github.com/pnnbao97/VieNeu-TTS" target="_blank">🦜 GitHub</a>
-             </div>
-         </div>
-         """)
-
-     with gr.Row(elem_classes="container", equal_height=False):
-
-         # --- LEFT: INPUT ---
-         with gr.Column(scale=3, variant="panel"):
-             gr.Markdown("### 📝 Văn bản đầu vào")
-             text_input = gr.Textbox(
-                 label="Nhập văn bản",
-                 placeholder="Nhập nội dung tiếng Việt cần chuyển thành giọng nói...",
-                 lines=4,
-                 value="Sự bùng nổ của trí tuệ nhân tạo đang định hình lại cách chúng ta làm việc và sinh sống. Từ xe tự lái đến trợ lý ảo thông minh, công nghệ đang dần xóa nhòa ranh giới giữa thực tại và những bộ phim viễn tưởng.",
-                 show_label=False
-             )
-
-             # Counter + warning
-             with gr.Row():
-                 char_count = gr.HTML("<div style='text-align: right; color: #64748B; font-size: 0.8rem;'>0 / 250 ký tự</div>")
-
-             gr.Markdown("### 🗣️ Chọn giọng đọc")
-             with gr.Tabs() as tabs:
-                 with gr.TabItem("👤 Giọng có sẵn (Preset)", id="preset_mode"):
-                     voice_select = gr.Dropdown(
-                         choices=list(VOICE_SAMPLES.keys()),
-                         value="Tuyên (nam miền Bắc)",
-                         label="Danh sách giọng",
-                         interactive=True
-                     )
-                     with gr.Accordion("Thông tin giọng mẫu", open=False):
-                         ref_audio_preview = gr.Audio(label="Audio mẫu", interactive=False, type="filepath")
-                         ref_text_preview = gr.Markdown("...")
-
-                 with gr.TabItem("🎙️ Giọng tùy chỉnh (Custom)", id="custom_mode"):
-                     gr.Markdown("Tải lên giọng của bạn (Zero-shot Cloning)")
-                     custom_audio = gr.Audio(label="File ghi âm (.wav)", type="filepath")
-                     custom_text = gr.Textbox(label="Nội dung ghi âm", placeholder="Nhập chính xác lời thoại...")
-
-             current_mode = gr.Textbox(visible=False, value="preset_mode")
-             btn_generate = gr.Button("Tổng hợp giọng nói", variant="primary", size="lg")
-
-         # --- RIGHT: OUTPUT ---
-         with gr.Column(scale=2):
-             gr.Markdown("### 🎧 Kết quả")
-             with gr.Group():
-                 audio_output = gr.Audio(label="Audio đầu ra", type="filepath", show_download_button=True, autoplay=True)
-                 status_output = gr.Textbox(label="Trạng thái", show_label=False, elem_classes="status-box", placeholder="Sẵn sàng...")
-
-     # --- EXAMPLES ---
-     with gr.Row(elem_classes="container"):
-         with gr.Column():
-             gr.Markdown("### 📚 Ví dụ mẫu")
-             gr.Examples(examples=EXAMPLES_LIST, inputs=[text_input, voice_select], label="Thử nghiệm nhanh")
-
-     # --- LOGIC ---
-     def update_count(text):
-         l = len(text)
-         if l > 250:
-             color = "#dc2626"  # Red
-             msg = f"⚠️ <b>{l} / 250</b> - Quá giới hạn!"
-         elif l > 200:
-             color = "#ea580c"  # Orange
-             msg = f"{l} / 250"
-         else:
-             color = "#64748B"  # Gray
-             msg = f"{l} / 250 ký tự"
-         return f"<div style='text-align: right; color: {color}; font-size: 0.8rem; font-weight: bold'>{msg}</div>"
-
-     text_input.change(update_count, text_input, char_count)
-
-     def update_ref_preview(voice):
-         audio, text = load_reference_info(voice)
-         return audio, f"> *\"{text}\"*"
-
-     voice_select.change(update_ref_preview, voice_select, [ref_audio_preview, ref_text_preview])
-     demo.load(update_ref_preview, voice_select, [ref_audio_preview, ref_text_preview])
-
-     # Tab handling - FIXED WITH *ARGS
-     tab_preset = tabs.children[0]
-     tab_custom = tabs.children[1]
-
-     # Use *args to accept any number of arguments (0 or 1) and avoid a warning
-     tab_preset.select(fn=lambda *args: "preset_mode", inputs=None, outputs=current_mode)
-     tab_custom.select(fn=lambda *args: "custom_mode", inputs=None, outputs=current_mode)
-
-     btn_generate.click(
-         fn=synthesize_speech,
-         inputs=[text_input, voice_select, custom_audio, custom_text, current_mode],
-         outputs=[audio_output, status_output]
-     )
-
- if __name__ == "__main__":
-     demo.queue().launch(
-         server_name="0.0.0.0",
-         server_port=7860
-     )
 
+ import spaces  # MUST be imported BEFORE anything else on HF Spaces ZeroGPU
+ import os
+ os.environ['SPACES_ZERO_GPU'] = '1'
+
+ import gradio as gr
+ import soundfile as sf
+ import tempfile
+ import torch
+ from vieneu_tts import VieNeuTTS, FastVieNeuTTS
+ import time
+ import numpy as np
+ import yaml
+ from utils.core_utils import split_text_into_chunks
+ import queue
+ import threading
+
+ print("⏳ Đang khởi động VieNeu-TTS...")
+
+ # --- LOAD CONFIG ---
+ CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.yaml")
+ try:
+     with open(CONFIG_PATH, "r", encoding="utf-8") as f:
+         _config = yaml.safe_load(f) or {}
+ except Exception as e:
+     raise RuntimeError(f"Không thể đọc config.yaml: {e}")
+
+ BACKBONE_CONFIGS = _config.get("backbone_configs", {})
+ CODEC_CONFIGS = _config.get("codec_configs", {})
+ VOICE_SAMPLES = _config.get("voice_samples", {})
+
+ _text_settings = _config.get("text_settings", {})
+ MAX_CHARS_PER_CHUNK = _text_settings.get("max_chars_per_chunk", 256)
+ MAX_TOTAL_CHARS_STREAMING = _text_settings.get("max_total_chars_streaming", 3000)
+
+ # --- INITIALIZE THE DEFAULT MODEL ---
+ print("📦 Đang tải model...")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"🖥️ Sử dụng thiết bị: {device.upper()}")
+
+ tts = None
+ using_fast_backend = False
+
+ try:
+     backbone_config = BACKBONE_CONFIGS["VieNeu-TTS (GPU)"]
+     codec_config = CODEC_CONFIGS["NeuCodec (Standard)"]
+
+     # Try FastVieNeuTTS if LMDeploy is installed for the GPU
+     if device == "cuda":
+         try:
+             print("🚀 Thử tải FastVieNeuTTS (LMDeploy backend)...")
+             tts = FastVieNeuTTS(
+                 backbone_repo=backbone_config["repo"],
+                 backbone_device="cuda",
+                 codec_repo=codec_config["repo"],
+                 codec_device="cuda",
+                 memory_util=0.3,
+                 tp=1,
+                 enable_prefix_caching=True,
+                 quant_policy=8,
+                 enable_triton=True,
+                 max_batch_size=8,
+             )
+             using_fast_backend = True
+             print("✅ FastVieNeuTTS đã tải thành công!")
+
+             # Pre-cache voices
+             print("📝 Pre-caching voices...")
+             for voice_name, voice_info in VOICE_SAMPLES.items():
+                 audio_path = voice_info["audio"]
+                 text_path = voice_info["text"]
+                 if os.path.exists(audio_path) and os.path.exists(text_path):
+                     with open(text_path, "r", encoding="utf-8") as f:
+                         ref_text = f.read()
+                     tts.get_cached_reference(voice_name, audio_path, ref_text)
+             print(f"✅ Cached {len(VOICE_SAMPLES)} voices")
+
+         except ImportError:
+             print("⚠️ LMDeploy không có, fallback về VieNeuTTS standard...")
+             using_fast_backend = False
+
+     # Fall back to the standard VieNeuTTS
+     if tts is None:
+         print("📦 Đang tải VieNeuTTS (Standard backend)...")
+         tts = VieNeuTTS(
+             backbone_repo=backbone_config["repo"],
+             backbone_device=device,
+             codec_repo=codec_config["repo"],
+             codec_device=device
+         )
+         using_fast_backend = False
+
+     print("✅ Model đã tải xong!")
+
+ except Exception as e:
+     print(f"⚠️ Không thể tải model (Chế độ UI Demo): {e}")
+     class MockTTS:
+         def encode_reference(self, path): return None
+         def get_cached_reference(self, name, path, text): return None
+         def infer(self, text, ref, ref_text):
+             time.sleep(1.5)
+             return np.random.uniform(-0.5, 0.5, 24000*3)
+         def infer_batch(self, texts, ref, ref_text):
+             return [self.infer(t, ref, ref_text) for t in texts]
+     tts = MockTTS()
+     using_fast_backend = False
+
+ # --- HELPER FUNCTIONS ---
+ def load_reference_info(voice_choice):
+     if voice_choice in VOICE_SAMPLES:
+         audio_path = VOICE_SAMPLES[voice_choice]["audio"]
+         text_path = VOICE_SAMPLES[voice_choice]["text"]
+         try:
+             if os.path.exists(text_path):
+                 with open(text_path, "r", encoding="utf-8") as f:
+                     ref_text = f.read()
+                 return audio_path, ref_text
+             else:
+                 return audio_path, "⚠️ Không tìm thấy file text mẫu."
+         except Exception as e:
+             return None, f"❌ Lỗi: {str(e)}"
+     return None, ""
+
+ @spaces.GPU(duration=120)
+ def synthesize_speech(text, voice_choice, custom_audio, custom_text, mode_tab, generation_mode, use_batch):
+     """Synthesize speech with GPU acceleration."""
+     global tts, using_fast_backend
+
+     if tts is None:
+         yield None, "⚠️ Model chưa được tải!"
+         return
+
+     if not text or text.strip() == "":
+         yield None, "⚠️ Vui lòng nhập văn bản!"
+         return
+
+     raw_text = text.strip()
+
+     # Set up the reference
+     if mode_tab == "custom_mode":
+         if custom_audio is None or not custom_text:
+             yield None, "⚠️ Thiếu Audio hoặc Text mẫu custom."
+             return
+         ref_audio_path = custom_audio
+         ref_text_raw = custom_text
+         use_cached = False
+     else:
+         if voice_choice not in VOICE_SAMPLES:
+             yield None, "⚠️ Vui lòng chọn giọng mẫu."
+             return
+         ref_audio_path = VOICE_SAMPLES[voice_choice]["audio"]
+         ref_text_path = VOICE_SAMPLES[voice_choice]["text"]
+
+         if not os.path.exists(ref_audio_path):
+             yield None, "❌ Không tìm thấy file audio mẫu."
+             return
+
+         with open(ref_text_path, "r", encoding="utf-8") as f:
+             ref_text_raw = f.read()
+         use_cached = True
+
+     yield None, "📄 Đang xử lý Reference..."
+
+     # Encode the reference
+     try:
+         if use_cached and using_fast_backend and hasattr(tts, 'get_cached_reference'):
+             ref_codes = tts.get_cached_reference(voice_choice, ref_audio_path, ref_text_raw)
+         else:
+             ref_codes = tts.encode_reference(ref_audio_path)
+
+         if isinstance(ref_codes, torch.Tensor):
+             ref_codes = ref_codes.cpu().numpy()
+     except Exception as e:
+         yield None, f"❌ Lỗi xử lý reference: {e}"
+         return
+
+     # Split the text
+     text_chunks = split_text_into_chunks(raw_text, max_chars=MAX_CHARS_PER_CHUNK)
+     total_chunks = len(text_chunks)
+
+     # === STANDARD MODE ===
+     if generation_mode == "Standard (Một lần)":
+         backend_name = "🚀 LMDeploy" if using_fast_backend else "📦 Standard"
+         batch_info = " (Batch Mode)" if use_batch and using_fast_backend and total_chunks > 1 else ""
+
+         yield None, f"{backend_name} Đang tổng hợp{batch_info} ({total_chunks} đoạn)..."
+
+         all_audio_segments = []
+         sr = 24000
+         silence_pad = np.zeros(int(sr * 0.15), dtype=np.float32)
+
+         start_time = time.time()
+
+         try:
+             # Batch processing when FastVieNeuTTS is available
+             if use_batch and using_fast_backend and hasattr(tts, 'infer_batch') and total_chunks > 1:
+                 yield None, f"⚡ Xử lý batch ({total_chunks} đoạn cùng lúc)..."
+                 chunk_wavs = tts.infer_batch(text_chunks, ref_codes, ref_text_raw)
+
+                 for i, chunk_wav in enumerate(chunk_wavs):
+                     if chunk_wav is not None and len(chunk_wav) > 0:
+                         all_audio_segments.append(chunk_wav)
+                         if i < total_chunks - 1:
+                             all_audio_segments.append(silence_pad)
+             else:
+                 # Sequential processing
+                 for i, chunk in enumerate(text_chunks):
+                     yield None, f"⏳ Đang xử lý đoạn {i+1}/{total_chunks}..."
+                     chunk_wav = tts.infer(chunk, ref_codes, ref_text_raw)
+
+                     if chunk_wav is not None and len(chunk_wav) > 0:
+                         all_audio_segments.append(chunk_wav)
+                         if i < total_chunks - 1:
+                             all_audio_segments.append(silence_pad)
+
+             if not all_audio_segments:
+                 yield None, "❌ Không sinh được audio nào."
+                 return
+
+             yield None, "💾 Đang ghép file và lưu..."
+
+             final_wav = np.concatenate(all_audio_segments)
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+                 sf.write(tmp.name, final_wav, sr)
+                 output_path = tmp.name
+
+             process_time = time.time() - start_time
+             speed_info = f", Tốc độ: {len(final_wav)/sr/process_time:.2f}x realtime" if process_time > 0 else ""
+
+             yield output_path, f"✅ Hoàn tất! (Thời gian: {process_time:.2f}s{speed_info}) {backend_name}"
+
+             # Clean up memory
+             if using_fast_backend and hasattr(tts, 'cleanup_memory'):
+                 tts.cleanup_memory()
+
+         except torch.cuda.OutOfMemoryError as e:
+             yield None, f"❌ GPU hết VRAM! Hãy thử giảm độ dài văn bản.\n\nChi tiết: {str(e)}"
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+             return
+
+         except Exception as e:
+             import traceback
+             traceback.print_exc()
+             yield None, f"❌ Lỗi: {str(e)}"
+             return
+
+     # === STREAMING MODE ===
+     else:
+         sr = 24000
+         crossfade_samples = int(sr * 0.03)
+         audio_queue = queue.Queue(maxsize=100)
+         PRE_BUFFER_SIZE = 3
+
+         end_event = threading.Event()
+         error_event = threading.Event()
+         error_msg = ""
+
+         def producer_thread():
+             nonlocal error_msg
+             try:
+                 previous_tail = None
+
+                 for i, chunk_text in enumerate(text_chunks):
+                     stream_gen = tts.infer_stream(chunk_text, ref_codes, ref_text_raw)
+
+                     for part_idx, audio_part in enumerate(stream_gen):
+                         if audio_part is None or len(audio_part) == 0:
+                             continue
+
+                         if previous_tail is not None and len(previous_tail) > 0:
+                             overlap = min(len(previous_tail), len(audio_part), crossfade_samples)
+                             if overlap > 0:
+                                 fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
+                                 fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
+
+                                 blended = (audio_part[:overlap] * fade_in +
+                                            previous_tail[-overlap:] * fade_out)
+
+                                 processed = np.concatenate([
+                                     previous_tail[:-overlap] if len(previous_tail) > overlap else np.array([]),
+                                     blended,
+                                     audio_part[overlap:]
+                                 ])
+                             else:
+                                 processed = np.concatenate([previous_tail, audio_part])
+
+                             tail_size = min(crossfade_samples, len(processed))
+                             previous_tail = processed[-tail_size:].copy()
+                             output_chunk = processed[:-tail_size] if len(processed) > tail_size else processed
+                         else:
+                             tail_size = min(crossfade_samples, len(audio_part))
+                             previous_tail = audio_part[-tail_size:].copy()
+                             output_chunk = audio_part[:-tail_size] if len(audio_part) > tail_size else audio_part
+
+                         if len(output_chunk) > 0:
+                             audio_queue.put((sr, output_chunk))
+
+                 if previous_tail is not None and len(previous_tail) > 0:
+                     audio_queue.put((sr, previous_tail))
+
+             except Exception as e:
+                 import traceback
+                 traceback.print_exc()
+                 error_msg = str(e)
+                 error_event.set()
+             finally:
+                 end_event.set()
+                 audio_queue.put(None)
+
+         threading.Thread(target=producer_thread, daemon=True).start()
+
+         yield (sr, np.zeros(int(sr * 0.05))), "📄 Đang buffering..."
+
+         pre_buffer = []
+         while len(pre_buffer) < PRE_BUFFER_SIZE:
+             try:
+                 item = audio_queue.get(timeout=5.0)
+                 if item is None:
+                     break
+                 pre_buffer.append(item)
+             except queue.Empty:
+                 if error_event.is_set():
+                     yield None, f"❌ Lỗi: {error_msg}"
+                     return
+                 break
+
+         full_audio_buffer = []
+         backend_info = "🚀 LMDeploy" if using_fast_backend else "📦 Standard"
+         for sr, audio_data in pre_buffer:
+             full_audio_buffer.append(audio_data)
+             yield (sr, audio_data), f"🔊 Đang phát ({backend_info})..."
+
+         while True:
+             try:
+                 item = audio_queue.get(timeout=0.05)
+                 if item is None:
+                     break
+                 sr, audio_data = item
+                 full_audio_buffer.append(audio_data)
+                 yield (sr, audio_data), f"🔊 Đang phát ({backend_info})..."
+             except queue.Empty:
+                 if error_event.is_set():
+                     yield None, f"❌ Lỗi: {error_msg}"
+                     break
+                 if end_event.is_set() and audio_queue.empty():
+                     break
+                 continue
+
+         if full_audio_buffer:
+             final_wav = np.concatenate(full_audio_buffer)
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+                 sf.write(tmp.name, final_wav, sr)
+             yield tmp.name, f"✅ Hoàn tất Streaming! ({backend_info})"
+
+         if using_fast_backend and hasattr(tts, 'cleanup_memory'):
+             tts.cleanup_memory()
+
+ # --- UI SETUP ---
+ theme = gr.themes.Soft(
+     primary_hue="indigo",
+     secondary_hue="cyan",
+     neutral_hue="slate",
+     font=[gr.themes.GoogleFont('Inter'), 'ui-sans-serif', 'system-ui'],
+ ).set(
+     button_primary_background_fill="linear-gradient(90deg, #6366f1 0%, #0ea5e9 100%)",
+     button_primary_background_fill_hover="linear-gradient(90deg, #4f46e5 0%, #0284c7 100%)",
+     block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)",
+ )
+
+ css = """
+ .container { max-width: 1200px; margin: auto; }
+ .header-box {
+     text-align: center;
+     margin-bottom: 25px;
+     padding: 25px;
+     background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
+     border-radius: 12px;
+     border: 1px solid #334155;
+     box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.3);
+ }
+ .header-title {
+     font-size: 2.5rem;
+     font-weight: 800;
+     color: white;
+     background: -webkit-linear-gradient(45deg, #60A5FA, #22D3EE);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     margin-bottom: 10px;
+ }
+ .header-desc {
+     font-size: 1.1rem;
+     color: #cbd5e1;
+     margin-bottom: 15px;
+ }
+ .link-group a {
+     text-decoration: none;
+     margin: 0 10px;
+     font-weight: 600;
+     color: #94a3b8;
+     transition: color 0.2s;
+ }
+ .link-group a:hover { color: #38bdf8; text-shadow: 0 0 5px rgba(56, 189, 248, 0.5); }
+ .status-box { font-weight: bold; text-align: center; border: none; background: transparent; }
+ """
+
+ EXAMPLES_LIST = [
+     ["Về miền Tây không chỉ để ngắm nhìn sông nước hữu tình, mà còn để cảm nhận tấm chân tình của người dân nơi đây. Cùng ngồi xuồng ba lá len lỏi qua rặng dừa nước, nghe câu vọng cổ ngọt ngào thì còn gì bằng.", "Vĩnh (nam miền Nam)"],
+     ["Hà Nội những ngày vào thu mang một vẻ đẹp trầm mặc và cổ kính đến lạ thường. Đi dạo quanh Hồ Gươm vào sáng sớm, hít hà mùi hoa sữa nồng nàn và thưởng thức chút cốm làng Vòng là trải nghiệm khó quên.", "Bình (nam miền Bắc)"],
+     ["Sự bùng nổ của trí tuệ nhân tạo đang định hình lại cách chúng ta làm việc và sinh sống. Từ xe tự lái đến trợ lý ảo thông minh, công nghệ đang dần xóa nhòa ranh giới giữa thực tại và những bộ phim viễn tưởng.", "Tuyên (nam miền Bắc)"],
+     ["Ngày xửa ngày xưa, ở một ngôi làng nọ có cô Tấm xinh đẹp, nết na nhưng sớm mồ côi mẹ. Dù bị mẹ kế và Cám hãm hại đủ đường, Tấm vẫn giữ được tấm lòng lương thiện và cuối cùng tìm được hạnh phúc xứng đáng.", "Đoan (nữ miền Nam)"],
+ ]
+
+ with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS Studio") as demo:
+     with gr.Column(elem_classes="container"):
+         # Header
+         gr.HTML("""
+         <div class="header-box">
+             <div class="header-title">🎙️ VieNeu-TTS Studio</div>
+             <div class="header-desc">
+                 Phiên bản: VieNeu-TTS-1000h (GPU-optimized với LMDeploy)
+             </div>
+             <div class="link-group">
+                 <a href="https://huggingface.co/pnnbao-ump/VieNeu-TTS" target="_blank">🤗 Model Card</a> •
+                 <a href="https://huggingface.co/datasets/pnnbao-ump/VieNeu-TTS-1000h" target="_blank">📖 Dataset 1000h</a> •
+                 <a href="https://github.com/pnnbao97/VieNeu-TTS" target="_blank">🦜 GitHub</a>
+             </div>
+         </div>
+         """)
+
+         # Status info
+         backend_status = "🚀 LMDeploy (GPU-optimized)" if using_fast_backend else "📦 Standard Backend"
+         gr.Markdown(f"**Backend hiện tại:** {backend_status}")
+
+     with gr.Row(elem_classes="container", equal_height=False):
+         # --- LEFT: INPUT ---
+         with gr.Column(scale=3, variant="panel"):
+             gr.Markdown("### 📝 Văn bản đầu vào")
+             text_input = gr.Textbox(
+                 label="Nhập văn bản",
+                 placeholder="Nhập nội dung tiếng Việt cần chuyển thành giọng nói...",
+                 lines=4,
+                 value="Sự bùng nổ của trí tuệ nhân tạo đang định hình lại cách chúng ta làm việc và sinh sống. Từ xe tự lái đến trợ lý ảo thông minh, công nghệ đang dần xóa nhòa ranh giới giữa thực tại và những bộ phim viễn tưởng.",
+                 show_label=False
+             )
+
+             gr.Markdown("### 🗣️ Chọn giọng đọc")
+             with gr.Tabs() as tabs:
+                 with gr.TabItem("👤 Giọng có sẵn (Preset)", id="preset_mode"):
+                     voice_select = gr.Dropdown(
+                         choices=list(VOICE_SAMPLES.keys()),
+                         value=list(VOICE_SAMPLES.keys())[0] if VOICE_SAMPLES else None,
+                         label="Danh sách giọng",
+                         interactive=True
+                     )
+                     with gr.Accordion("Thông tin giọng mẫu", open=False):
+                         ref_audio_preview = gr.Audio(label="Audio mẫu", interactive=False, type="filepath")
+                         ref_text_preview = gr.Markdown("...")
+
+                 with gr.TabItem("🎙️ Giọng tùy chỉnh (Custom)", id="custom_mode"):
+                     gr.Markdown("Tải lên giọng của bạn (Zero-shot Cloning)")
+                     custom_audio = gr.Audio(label="File ghi âm (.wav)", type="filepath")
+                     custom_text = gr.Textbox(label="Nội dung ghi âm", placeholder="Nhập chính xác lời thoại...")
+
+             gr.Markdown("### ⚙️ Cài đặt tổng hợp")
+             generation_mode = gr.Radio(
+                 ["Standard (Một lần)", "Streaming (Thời gian thực)"],
+                 value="Standard (Một lần)",
+                 label="Chế độ sinh"
+             )
+             use_batch = gr.Checkbox(
+                 value=True,
+                 label="⚡ Batch Processing (chỉ có hiệu lực khi dùng LMDeploy backend)",
+                 info="Xử lý nhiều đoạn cùng lúc để tăng tốc"
+             )
+
+             current_mode = gr.Textbox(visible=False, value="preset_mode")
+             btn_generate = gr.Button("Tổng hợp giọng nói", variant="primary", size="lg")
+
+         # --- RIGHT: OUTPUT ---
+         with gr.Column(scale=2):
+             gr.Markdown("### 🎧 Kết quả")
+             with gr.Group():
+                 audio_output = gr.Audio(label="Audio đầu ra", type="filepath", show_download_button=True, autoplay=True)
+                 status_output = gr.Textbox(label="Trạng thái", show_label=False, elem_classes="status-box", placeholder="Sẵn sàng...")
+
+     # --- EXAMPLES ---
+     with gr.Row(elem_classes="container"):
+         with gr.Column():
+             gr.Markdown("### 📚 Ví dụ mẫu")
+             gr.Examples(examples=EXAMPLES_LIST, inputs=[text_input, voice_select], label="Thử nghiệm nhanh")
+
+     # --- EVENT HANDLERS ---
+     def update_ref_preview(voice):
+         audio, text = load_reference_info(voice)
+         return audio, f"> *\"{text}\"*"
+
+     voice_select.change(update_ref_preview, voice_select, [ref_audio_preview, ref_text_preview])
+     demo.load(update_ref_preview, voice_select, [ref_audio_preview, ref_text_preview])
+
+     # Tab handling
+     tabs.children[0].select(fn=lambda: "preset_mode", outputs=current_mode)
+     tabs.children[1].select(fn=lambda: "custom_mode", outputs=current_mode)
+
+     btn_generate.click(
+         fn=synthesize_speech,
+         inputs=[text_input, voice_select, custom_audio, custom_text, current_mode, generation_mode, use_batch],
+         outputs=[audio_output, status_output]
+     )
+
+ if __name__ == "__main__":
+     demo.queue().launch(
+         server_name="0.0.0.0",
+         server_port=7860
+     )
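The subtlest piece of the new streaming path is the short linear crossfade the producer applies between consecutive audio parts so chunk boundaries don't click. A minimal standalone sketch of that blend, under the same 30 ms window at 24 kHz as the code above (numpy only; the two chunks here are dummy data, not model output):

import numpy as np

SR = 24000
CROSSFADE = int(SR * 0.03)  # 720 samples: the 30 ms overlap window used above

def crossfade(prev_tail, part):
    # Fade the previous chunk's tail out while fading the next chunk in.
    overlap = min(len(prev_tail), len(part), CROSSFADE)
    if overlap == 0:
        return np.concatenate([prev_tail, part])
    fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
    fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
    blended = part[:overlap] * fade_in + prev_tail[-overlap:] * fade_out
    return np.concatenate([prev_tail[:-overlap], blended, part[overlap:]])

a = np.ones(2400, dtype=np.float32)   # two dummy 100 ms chunks
b = -np.ones(2400, dtype=np.float32)
print(crossfade(a, b).shape)          # (4080,) = 2400 + 2400 - 720 overlapping samples

producer_thread does the same thing incrementally: it holds back the last crossfade_samples of each processed chunk as previous_tail and emits the rest, so playback can begin before later chunks have been generated.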
requirements.txt CHANGED
@@ -1,9 +1,9 @@
- gradio
- spaces
- torchaudio
- transformers
- librosa
- soundfile
- numpy
- phonemizer
- neucodec
+ torchaudio
+ transformers
+ librosa
+ soundfile
+ numpy
+ phonemizer
+ neucodec
+ lmdeploy
+ pyyaml
sample/Bình (nam miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f896d618fc46c3e131eda7b4168e25e9c2fb2d7ea0e864bedff2577fbd0bd30
+ size 2089

sample/Dung (nữ miền Nam).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dc4d65b6504470cb00e46763915060590595fbe4d47912eeacecd2bf1bade262
+ size 2153

sample/Hương (nữ miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:919035b7c762956a7d568cebc6e69fea22eb9be02bf906c1d32c1db1d8c7b9ff
+ size 2217

sample/Ly (nữ miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69b6bc9bb1062122dc3755be907d87f232fa8be5129b54f6994dead35f4935c6
+ size 2153

sample/Nguyên (nam miền Nam).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e6ebaa0b2977589afa7e7f811b0553151bd8312c96a70b1b666bd9d0fd50edf
+ size 2345

sample/Ngọc (nữ miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:78ab670f177092dc8586e45536faea20fdb84471dc8d8a8b1b95dd76a4ed3d0d
+ size 2281

sample/Sơn (nam miền Nam).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:114cb04ee2357d06de2f038853bbeb0dc57fc8ed30e085118a9e0bf5a70f7857
+ size 2281

sample/Tuyên (nam miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e79eb6ee9cc7cd35cb4fbbef107249ed3209608b59644c52f55a34941a531873
+ size 2473

sample/Vĩnh (nam miền Nam).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c87342d3a6a8cbaaf2139c21e7554eea19aba6aa03248e4426238a1c2507e447
+ size 2217

sample/Đoan (nữ miền Nam).pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28b48dbae193adc88aa26243086ba3ce862def7035d9793613c2967df29f9afe
+ size 2793
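These sample/*.pt files are committed as Git LFS pointers, so only the sha256 oid and byte size (a couple of kilobytes each) live in the repo itself. Presumably they hold small precomputed per-voice reference data; that is an assumption, since no code in this diff reads them directly. If so, a quick way to inspect one:

import torch

# Assumption: each sample/<voice>.pt stores a small precomputed reference
# object for that voice (the LFS pointers above report roughly 2-3 KB each).
obj = torch.load("sample/Tuyên (nam miền Bắc).pt", map_location="cpu")
print(type(obj), getattr(obj, "shape", None))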
utils/__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/utils/__pycache__/__init__.cpython-312.pyc and b/utils/__pycache__/__init__.cpython-312.pyc differ
 
utils/__pycache__/core_utils.cpython-312.pyc ADDED
Binary file (1.9 kB).
 
utils/__pycache__/normalize_text.cpython-312.pyc CHANGED
Binary files a/utils/__pycache__/normalize_text.cpython-312.pyc and b/utils/__pycache__/normalize_text.cpython-312.pyc differ
 
utils/__pycache__/phonemize_text.cpython-312.pyc CHANGED
Binary files a/utils/__pycache__/phonemize_text.cpython-312.pyc and b/utils/__pycache__/phonemize_text.cpython-312.pyc differ
 
utils/core_utils.py ADDED
@@ -0,0 +1,47 @@
+ import re
+ from typing import List
+
+ def split_text_into_chunks(text: str, max_chars: int = 256) -> List[str]:
+     """
+     Split raw text into chunks no longer than max_chars.
+     Preference is given to sentence boundaries; otherwise falls back to word-based splitting.
+     """
+     sentences = re.split(r"(?<=[\.\!\?\…])\s+", text.strip())
+     chunks: List[str] = []
+     buffer = ""
+
+     def flush_buffer():
+         nonlocal buffer
+         if buffer:
+             chunks.append(buffer.strip())
+         buffer = ""
+
+     for sentence in sentences:
+         sentence = sentence.strip()
+         if not sentence:
+             continue
+
+         if len(sentence) <= max_chars:
+             candidate = f"{buffer} {sentence}".strip() if buffer else sentence
+             if len(candidate) <= max_chars:
+                 buffer = candidate
+             else:
+                 flush_buffer()
+                 buffer = sentence
+             continue
+
+         flush_buffer()
+         words = sentence.split()
+         current = ""
+         for word in words:
+             candidate = f"{current} {word}".strip() if current else word
+             if len(candidate) > max_chars and current:
+                 chunks.append(current.strip())
+                 current = word
+             else:
+                 current = candidate
+         if current:
+             chunks.append(current.strip())
+
+     flush_buffer()
+     return [chunk for chunk in chunks if chunk]
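To see how the new chunker behaves, a quick sanity check with a made-up snippet (hypothetical input; any text works):

from utils.core_utils import split_text_into_chunks

text = (
    "Câu thứ nhất khá ngắn. "
    "Câu thứ hai cũng vậy! "
    "Câu thứ ba dài hơn hai câu trên một chút nhưng vẫn trọn vẹn?"
)
for i, chunk in enumerate(split_text_into_chunks(text, max_chars=60)):
    print(i, len(chunk), repr(chunk))
# Sentences are packed greedily into a buffer until adding the next one
# would exceed max_chars; a single over-long sentence falls back to
# word-level splitting.

This is what lets the new app.py accept inputs longer than a single model call: each chunk goes through tts.infer (or infer_batch) and the pieces are joined with 150 ms of silence, or crossfaded in streaming mode.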
utils/normalize_text.py CHANGED
@@ -1,408 +1,408 @@
- import re
-
- class VietnameseTTSNormalizer:
-     """
-     A text normalizer for Vietnamese Text-to-Speech systems.
-     Converts numbers, dates, units, and special characters into readable Vietnamese text.
-     """
-
-     def __init__(self):
-         self.units = {
-             'km': 'ki lô mét', 'dm': 'đê xi mét', 'cm': 'xen ti mét',
-             'mm': 'mi li mét', 'nm': 'na nô mét', 'µm': 'mic rô mét',
-             'μm': 'mic rô mét', 'm': 'mét',
-
-             'kg': 'ki lô gam', 'g': 'gam', 'mg': 'mi li gam',
-
-             'km²': 'ki lô mét vuông', 'km2': 'ki lô mét vuông',
-             'm²': 'mét vuông', 'm2': 'mét vuông',
-             'cm²': 'xen ti mét vuông', 'cm2': 'xen ti mét vuông',
-             'mm²': 'mi li mét vuông', 'mm2': 'mi li mét vuông',
-             'ha': 'héc ta',
-
-             'km³': 'ki lô mét khối', 'km3': 'ki lô mét khối',
-             'm³': 'mét khối', 'm3': 'mét khối',
-             'cm³': 'xen ti mét khối', 'cm3': 'xen ti mét khối',
-             'mm³': 'mi li mét khối', 'mm3': 'mi li mét khối',
-             'l': 'lít', 'dl': 'đê xi lít', 'ml': 'mi li lít', 'hl': 'héc tô lít',
-
-             'v': 'vôn', 'kv': 'ki lô vôn', 'mv': 'mi li vôn',
-             'a': 'am pe', 'ma': 'mi li am pe', 'ka': 'ki lô am pe',
-             'w': 'oát', 'kw': 'ki lô oát', 'mw': 'mê ga oát', 'gw': 'gi ga oát',
-             'kwh': 'ki lô oát giờ', 'mwh': 'mê ga oát giờ', 'wh': 'oát giờ',
-             'ω': 'ôm', 'ohm': 'ôm', 'kω': 'ki lô ôm', 'mω': 'mê ga ôm',
-
-             'hz': 'héc', 'khz': 'ki lô héc', 'mhz': 'mê ga héc', 'ghz': 'gi ga héc',
-
-             'pa': 'pát cal', 'kpa': 'ki lô pát cal', 'mpa': 'mê ga pát cal',
-             'bar': 'ba', 'mbar': 'mi li ba', 'atm': 'át mốt phia', 'psi': 'pi ét xai',
-
-             'j': 'giun', 'kj': 'ki lô giun',
-             'cal': 'ca lo', 'kcal': 'ki lô ca lo',
-         }
-
-         self.digits = ['không', 'một', 'hai', 'ba', 'bốn',
-                        'năm', 'sáu', 'bảy', 'tám', 'chín']
-
-     def normalize(self, text):
-         """Main normalization pipeline."""
-         text = text.lower()
-         text = self._normalize_temperature(text)
-         text = self._normalize_currency(text)
-         text = self._normalize_percentage(text)
-         text = self._normalize_units(text)
-         text = self._normalize_time(text)
-         text = self._normalize_date(text)
-         text = self._normalize_phone(text)
-         text = self._normalize_numbers(text)
-         text = self._number_to_words(text)
-         text = self._normalize_special_chars(text)
-         text = self._normalize_whitespace(text)
-         return text
-
-     def _normalize_temperature(self, text):
-         """Convert temperature notation to words."""
-         text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'âm \1 độ xê', text, flags=re.IGNORECASE)
-         text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'âm \1 độ ép', text, flags=re.IGNORECASE)
-         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'\1 độ xê', text, flags=re.IGNORECASE)
-         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'\1 độ ép', text, flags=re.IGNORECASE)
-         text = re.sub(r'°', ' độ ', text)
-         return text
-
-     def _normalize_currency(self, text):
-         """Convert currency notation to words."""
-         def decimal_currency(match):
-             whole = match.group(1)
-             decimal = match.group(2)
-             unit = match.group(3)
-             decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
-             unit_map = {'k': 'nghìn', 'm': 'triệu', 'b': 'tỷ'}
-             unit_word = unit_map.get(unit.lower(), unit)
-             return f"{whole} phẩy {decimal_words} {unit_word}"
-
-         text = re.sub(r'(\d+)[.,](\d+)\s*([kmb])\b', decimal_currency, text, flags=re.IGNORECASE)
-         text = re.sub(r'(\d+)\s*k\b', r'\1 nghìn', text, flags=re.IGNORECASE)
-         text = re.sub(r'(\d+)\s*m\b', r'\1 triệu', text, flags=re.IGNORECASE)
-         text = re.sub(r'(\d+)\s*b\b', r'\1 tỷ', text, flags=re.IGNORECASE)
-         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*đ\b', r'\1 đồng', text)
-         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*vnd\b', r'\1 đồng', text, flags=re.IGNORECASE)
-         text = re.sub(r'\$\s*(\d+(?:[.,]\d+)?)', r'\1 đô la', text)
-         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*\$', r'\1 đô la', text)
-         return text
-
-     def _normalize_percentage(self, text):
-         """Convert percentage to words."""
-         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*%', r'\1 phần trăm', text)
-         return text
-
-     def _normalize_units(self, text):
-         """Convert measurement units to words."""
-         def expand_compound_with_number(match):
-             number = match.group(1)
-             unit1 = match.group(2).lower()
-             unit2 = match.group(3).lower()
-             full_unit1 = self.units.get(unit1, unit1)
-             full_unit2 = self.units.get(unit2, unit2)
-             return f"{number} {full_unit1} trên {full_unit2}"
-
-         def expand_compound_without_number(match):
-             unit1 = match.group(1).lower()
-             unit2 = match.group(2).lower()
-             full_unit1 = self.units.get(unit1, unit1)
-             full_unit2 = self.units.get(unit2, unit2)
-             return f"{full_unit1} trên {full_unit2}"
-
-         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
-                       expand_compound_with_number, text)
-         text = re.sub(r'\b([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
-                       expand_compound_without_number, text)
-
-         sorted_units = sorted(self.units.items(), key=lambda x: len(x[0]), reverse=True)
-         for unit, full_name in sorted_units:
-             pattern = r'(\d+(?:[.,]\d+)?)\s*' + re.escape(unit) + r'\b'
-             text = re.sub(pattern, rf'\1 {full_name}', text, flags=re.IGNORECASE)
-
-         for unit, full_name in sorted_units:
-             if any(c in unit for c in '²³°'):
-                 pattern = r'\b' + re.escape(unit) + r'\b'
-                 text = re.sub(pattern, full_name, text, flags=re.IGNORECASE)
-
-         return text
-
-     def _normalize_time(self, text):
-         """Convert time notation to words with validation."""
-
-         def validate_and_convert_time(match):
-             """Validate time components before converting."""
-             groups = match.groups()
-
-             # HH:MM:SS format
-             if len(groups) == 3:
-                 hour, minute, second = groups
-                 hour_int, minute_int, second_int = int(hour), int(minute), int(second)
-
-                 # Validate ranges
-                 if not (0 <= hour_int <= 23):
-                     return match.group(0)  # Return original if invalid
-                 if not (0 <= minute_int <= 59):
-                     return match.group(0)
-                 if not (0 <= second_int <= 59):
-                     return match.group(0)
-
-                 return f"{hour} giờ {minute} phút {second} giây"
-
-             # HH:MM or HHhMM format
-             elif len(groups) == 2:
-                 hour, minute = groups
-                 hour_int, minute_int = int(hour), int(minute)
-
-                 # Validate ranges
-                 if not (0 <= hour_int <= 23):
-                     return match.group(0)
-                 if not (0 <= minute_int <= 59):
-                     return match.group(0)
-
-                 return f"{hour} giờ {minute} phút"
-
-             # HHh format
-             else:
-                 hour = groups[0]
-                 hour_int = int(hour)
-
-                 if not (0 <= hour_int <= 23):
-                     return match.group(0)
-
-                 return f"{hour} giờ"
-
-         # Apply patterns with validation
-         text = re.sub(r'(\d{1,2}):(\d{2}):(\d{2})', validate_and_convert_time, text)
-         text = re.sub(r'(\d{1,2}):(\d{2})', validate_and_convert_time, text)
-         text = re.sub(r'(\d{1,2})h(\d{2})', validate_and_convert_time, text)
-         text = re.sub(r'(\d{1,2})h\b', validate_and_convert_time, text)
-
-         return text
-
-     def _normalize_date(self, text):
-         """Convert date notation to words with validation."""
-
-         def is_valid_date(day, month, year):
-             """Check if date components are valid."""
-             day, month, year = int(day), int(month), int(year)
-
-             # Basic range checks
-             if not (1 <= day <= 31):
-                 return False
-             if not (1 <= month <= 12):
-                 return False
-
-             return True
-
-         def date_to_text(match):
-             day, month, year = match.groups()
-             if is_valid_date(day, month, year):
-                 return f"ngày {day} tháng {month} năm {year}"
-             return match.group(0)  # Return original if invalid
-
-         def date_iso_to_text(match):
-             year, month, day = match.groups()
-             if is_valid_date(day, month, year):
-                 return f"ngày {day} tháng {month} năm {year}"
-             return match.group(0)
-
-         def date_short_year(match):
-             day, month, year = match.groups()
-             full_year = f"20{year}" if int(year) < 50 else f"19{year}"
-             if is_valid_date(day, month, full_year):
-                 return f"ngày {day} tháng {month} năm {full_year}"
-             return match.group(0)
-
-         # Apply patterns with validation
-         text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b',
-                       lambda m: date_to_text(m).replace('ngày ngày', 'ngày'), text)
-         text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b',
-                       lambda m: date_short_year(m).replace('ngày ngày', 'ngày'), text)
-         text = re.sub(r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', date_iso_to_text, text)
-         text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b', date_to_text, text)
-         text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b', date_short_year, text)
-
-         return text
-
-     def _normalize_phone(self, text):
-         """Convert phone numbers to digit-by-digit reading."""
-         def phone_to_text(match):
-             phone = match.group(0)
-             phone = re.sub(r'[^\d]', '', phone)
-
-             if phone.startswith('84') and len(phone) >= 10:
-                 phone = '0' + phone[2:]
-
-             if 10 <= len(phone) <= 11:
-                 words = [self.digits[int(d)] for d in phone]
-                 return ' '.join(words) + ' '
-
-             return match.group(0)
-
-         text = re.sub(r'(\+84|84)[\s\-\.]?\d[\d\s\-\.]{7,}', phone_to_text, text)
-         text = re.sub(r'\b0\d[\d\s\-\.]{8,}', phone_to_text, text)
-         return text
-
-     def _normalize_numbers(self, text):
-         text = re.sub(r'(\d+(?:[,.]\d+)?)%', lambda m: f'{m.group(1)} phần trăm', text)
-         # 1. Strip thousands separators first
-         text = re.sub(r'(\d{1,3})(?:\.(\d{3}))+', lambda m: m.group(0).replace('.', ''), text)
-
-         # 2. Convert decimal numbers to words
-         def decimal_to_words(match):
-             whole = match.group(1)
-             decimal = match.group(2)
-             decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
-             separator = 'phẩy' if ',' in match.group(0) else 'chấm'
-             return f"{whole} {separator} {decimal_words}"
-
-         # 2a. Comma separator
-         text = re.sub(r'(\d+),(\d+)', decimal_to_words, text)
-         # 2b. Period (1-2 decimal digits)
-         text = re.sub(r'(\d+)\.(\d{1,2})\b', decimal_to_words, text)
-
-         return text
-
-     def _read_two_digits(self, n):
-         """Read two-digit numbers in Vietnamese."""
-         if n < 10:
-             return self.digits[n]
-         elif n == 10:
-             return "mười"
-         elif n < 20:
-             if n == 15:
-                 return "mười lăm"
-             return f"mười {self.digits[n % 10]}"
-         else:
-             tens = n // 10
-             ones = n % 10
-             if ones == 0:
-                 return f"{self.digits[tens]} mươi"
-             elif ones == 1:
-                 return f"{self.digits[tens]} mươi mốt"
-             elif ones == 5:
-                 return f"{self.digits[tens]} mươi lăm"
-             else:
-                 return f"{self.digits[tens]} mươi {self.digits[ones]}"
-
-     def _read_three_digits(self, n):
-         """Read three-digit numbers in Vietnamese."""
-         if n < 100:
-             return self._read_two_digits(n)
-
-         hundreds = n // 100
-         remainder = n % 100
-         result = f"{self.digits[hundreds]} trăm"
-
-         if remainder == 0:
-             return result
-         elif remainder < 10:
-             result += f" lẻ {self.digits[remainder]}"
-         else:
-             result += f" {self._read_two_digits(remainder)}"
-
-         return result
-
-     def _convert_number_to_words(self, num):
-         """Convert a number to Vietnamese words."""
-         if num == 0:
-             return "không"
-
-         if num < 0:
-             return f"âm {self._convert_number_to_words(-num)}"
-
-         if num >= 1000000000:
-             billion = num // 1000000000
-             remainder = num % 1000000000
-             result = f"{self._read_three_digits(billion)} tỷ"
-             if remainder > 0:
-                 result += f" {self._convert_number_to_words(remainder)}"
-             return result
-
-         elif num >= 1000000:
-             million = num // 1000000
-             remainder = num % 1000000
-             result = f"{self._read_three_digits(million)} triệu"
-             if remainder > 0:
-                 result += f" {self._convert_number_to_words(remainder)}"
-             return result
-
-         elif num >= 1000:
-             thousand = num // 1000
-             remainder = num % 1000
-             result = f"{self._read_three_digits(thousand)} nghìn"
-             if remainder > 0:
-                 if remainder < 100:
-                     result += f" không trăm {self._read_two_digits(remainder)}"
-                 else:
-                     result += f" {self._read_three_digits(remainder)}"
-             return result
-
-         else:
-             return self._read_three_digits(num)
-
-     def _number_to_words(self, text):
-         """Convert all remaining numbers to words."""
-         def convert_number(match):
-             num = int(match.group(0))
-             return self._convert_number_to_words(num)
-
-         text = re.sub(r'\b\d+\b', convert_number, text)
-         return text
-
-     def _normalize_special_chars(self, text):
-         """Handle special characters."""
-         text = text.replace('&', ' và ')
-         text = text.replace('+', ' cộng ')
-         text = text.replace('=', ' bằng ')
-         text = text.replace('#', ' thăng ')
-         text = re.sub(r'[\[\]\(\)\{\}]', ' ', text)
-         text = re.sub(r'\s+[-–—]+\s+', ' ', text)
-         text = re.sub(r'\.{2,}', ' ', text)
-         text = re.sub(r'\s+\.\s+', ' ', text)
-         text = re.sub(r'[^\w\sàáảãạăắằẳẵặâấầẩẫậèéẻẽẹêếềểễệìíỉĩịòóỏõọôốồổỗộơớờởỡợùúủũụưứừửữựỳýỷỹỵđ.,!?;:@%]', ' ', text)
-         return text
-
-     def _normalize_whitespace(self, text):
-         """Normalize whitespace."""
-         text = re.sub(r'\s+', ' ', text)
-         text = text.strip()
-         return text
-
-
- if __name__ == "__main__":
-     normalizer = VietnameseTTSNormalizer()
-
-     test_texts = [
-         "Giá 2.500.000đ (giảm 50%), mua trước 14h30 ngày 15/12/2025",
-         "Liên hệ: 0912-345-678 hoặc email@example.com",
-         "Tốc độ 120km/h, trọng lượng 75kg",
-         "Nhiệt độ 36,5°C, độ ẩm 80%",
-         "Số pi = 3,14159",
-         "Giá trị tăng 2.5M, đạt 10B",
-         "Nhiệt độ -15°C vào mùa đông",
-         "Điện áp 220V, công suất 2.5kW, tần số 50Hz",
-         "Tôi đi lấy l nước về nhà",
-         "Cần 5l nước cho công thức này",
-         "Vận tốc ánh sáng 299792km/s",
-         "Mật độ dân số 450 người/km2",
-         "Công suất 100 W/m2",
-         "Hôm nay 2025-01-15",
-         "Gọi +84 912 345 678",
-         "Nhiệt độ 25°C lúc 14:30:45",
-         "Ngày 15/12/25",
-         "Giá 3.140.159",
-     ]
-
-     print("=" * 80)
-     print("VIETNAMESE TTS NORMALIZATION TEST")
-     print("=" * 80)
-
-     for text in test_texts:
-         print(f"\n📝 Input: {text}")
-         normalized = normalizer.normalize(text)
-         print(f"🎵 Output: {normalized}")
-         print("-" * 80)
 
+ import re
+
+ class VietnameseTTSNormalizer:
+     """
+     A text normalizer for Vietnamese Text-to-Speech systems.
+     Converts numbers, dates, units, and special characters into readable Vietnamese text.
+     """
+
+     def __init__(self):
+         self.units = {
+             'km': 'ki lô mét', 'dm': 'đê xi mét', 'cm': 'xen ti mét',
+             'mm': 'mi li mét', 'nm': 'na nô mét', 'µm': 'mic rô mét',
+             'μm': 'mic rô mét', 'm': 'mét',
+
+             'kg': 'ki lô gam', 'g': 'gam', 'mg': 'mi li gam',
+
+             'km²': 'ki lô mét vuông', 'km2': 'ki lô mét vuông',
+             'm²': 'mét vuông', 'm2': 'mét vuông',
+             'cm²': 'xen ti mét vuông', 'cm2': 'xen ti mét vuông',
+             'mm²': 'mi li mét vuông', 'mm2': 'mi li mét vuông',
+             'ha': 'héc ta',
+
+             'km³': 'ki lô mét khối', 'km3': 'ki lô mét khối',
+             'm³': 'mét khối', 'm3': 'mét khối',
+             'cm³': 'xen ti mét khối', 'cm3': 'xen ti mét khối',
+             'mm³': 'mi li mét khối', 'mm3': 'mi li mét khối',
+             'l': 'lít', 'dl': 'đê xi lít', 'ml': 'mi li lít', 'hl': 'héc tô lít',
+
+             'v': 'vôn', 'kv': 'ki lô vôn', 'mv': 'mi li vôn',
+             'a': 'am pe', 'ma': 'mi li am pe', 'ka': 'ki lô am pe',
+             'w': 'oát', 'kw': 'ki lô oát', 'mw': 'mê ga oát', 'gw': 'gi ga oát',
+             'kwh': 'ki lô oát giờ', 'mwh': 'mê ga oát giờ', 'wh': 'oát giờ',
+             'ω': 'ôm', 'ohm': 'ôm', 'kω': 'ki lô ôm', 'mω': 'mê ga ôm',
+
+             'hz': 'héc', 'khz': 'ki lô héc', 'mhz': 'mê ga héc', 'ghz': 'gi ga héc',
+
+             'pa': 'pát cal', 'kpa': 'ki lô pát cal', 'mpa': 'mê ga pát cal',
+             'bar': 'ba', 'mbar': 'mi li ba', 'atm': 'át mốt phia', 'psi': 'pi ét xai',
+
+             'j': 'giun', 'kj': 'ki lô giun',
+             'cal': 'ca lo', 'kcal': 'ki lô ca lo',
+         }
+
+         self.digits = ['không', 'một', 'hai', 'ba', 'bốn',
+                        'năm', 'sáu', 'bảy', 'tám', 'chín']
+
+     def normalize(self, text):
+         """Main normalization pipeline."""
+         text = text.lower()
+         text = self._normalize_temperature(text)
+         text = self._normalize_currency(text)
+         text = self._normalize_percentage(text)
+         text = self._normalize_units(text)
+         text = self._normalize_time(text)
+         text = self._normalize_date(text)
+         text = self._normalize_phone(text)
+         text = self._normalize_numbers(text)
+         text = self._number_to_words(text)
+         text = self._normalize_special_chars(text)
+         text = self._normalize_whitespace(text)
+         return text
+
+     def _normalize_temperature(self, text):
+         """Convert temperature notation to words."""
+         text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'âm \1 độ xê', text, flags=re.IGNORECASE)
+         text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'âm \1 độ ép', text, flags=re.IGNORECASE)
+         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'\1 độ xê', text, flags=re.IGNORECASE)
+         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'\1 độ ép', text, flags=re.IGNORECASE)
+         text = re.sub(r'°', ' độ ', text)
+         return text
+
+     def _normalize_currency(self, text):
+         """Convert currency notation to words."""
+         def decimal_currency(match):
+             whole = match.group(1)
+             decimal = match.group(2)
+             unit = match.group(3)
+             decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
+             unit_map = {'k': 'nghìn', 'm': 'triệu', 'b': 'tỷ'}
+             unit_word = unit_map.get(unit.lower(), unit)
+             return f"{whole} phẩy {decimal_words} {unit_word}"
+
+         text = re.sub(r'(\d+)[.,](\d+)\s*([kmb])\b', decimal_currency, text, flags=re.IGNORECASE)
+         text = re.sub(r'(\d+)\s*k\b', r'\1 nghìn', text, flags=re.IGNORECASE)
+         text = re.sub(r'(\d+)\s*m\b', r'\1 triệu', text, flags=re.IGNORECASE)
+         text = re.sub(r'(\d+)\s*b\b', r'\1 tỷ', text, flags=re.IGNORECASE)
+         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*đ\b', r'\1 đồng', text)
+         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*vnd\b', r'\1 đồng', text, flags=re.IGNORECASE)
+         text = re.sub(r'\$\s*(\d+(?:[.,]\d+)?)', r'\1 đô la', text)
+         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*\$', r'\1 đô la', text)
+         return text
+
+     def _normalize_percentage(self, text):
+         """Convert percentage to words."""
+         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*%', r'\1 phần trăm', text)
+         return text
+
+     def _normalize_units(self, text):
+         """Convert measurement units to words."""
+         def expand_compound_with_number(match):
+             number = match.group(1)
+             unit1 = match.group(2).lower()
+             unit2 = match.group(3).lower()
+             full_unit1 = self.units.get(unit1, unit1)
+             full_unit2 = self.units.get(unit2, unit2)
+             return f"{number} {full_unit1} trên {full_unit2}"
+
+         def expand_compound_without_number(match):
+             unit1 = match.group(1).lower()
+             unit2 = match.group(2).lower()
+             full_unit1 = self.units.get(unit1, unit1)
+             full_unit2 = self.units.get(unit2, unit2)
+             return f"{full_unit1} trên {full_unit2}"
+
+         text = re.sub(r'(\d+(?:[.,]\d+)?)\s*([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
+                       expand_compound_with_number, text)
+         text = re.sub(r'\b([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
+                       expand_compound_without_number, text)
+
+         sorted_units = sorted(self.units.items(), key=lambda x: len(x[0]), reverse=True)
+         for unit, full_name in sorted_units:
+             pattern = r'(\d+(?:[.,]\d+)?)\s*' + re.escape(unit) + r'\b'
+             text = re.sub(pattern, rf'\1 {full_name}', text, flags=re.IGNORECASE)
+
+         for unit, full_name in sorted_units:
+             if any(c in unit for c in '²³°'):
+                 pattern = r'\b' + re.escape(unit) + r'\b'
+                 text = re.sub(pattern, full_name, text, flags=re.IGNORECASE)
+
+         return text
+
+     def _normalize_time(self, text):
+         """Convert time notation to words with validation."""
+
+         def validate_and_convert_time(match):
+             """Validate time components before converting."""
+             groups = match.groups()
+
+             # HH:MM:SS format
+             if len(groups) == 3:
+                 hour, minute, second = groups
+                 hour_int, minute_int, second_int = int(hour), int(minute), int(second)
+
+                 # Validate ranges
+                 if not (0 <= hour_int <= 23):
+                     return match.group(0)  # Return original if invalid
+                 if not (0 <= minute_int <= 59):
+                     return match.group(0)
+                 if not (0 <= second_int <= 59):
+                     return match.group(0)
+
+                 return f"{hour} giờ {minute} phút {second} giây"
+
+             # HH:MM or HHhMM format
+             elif len(groups) == 2:
+                 hour, minute = groups
+                 hour_int, minute_int = int(hour), int(minute)
+
+                 # Validate ranges
+                 if not (0 <= hour_int <= 23):
+                     return match.group(0)
+                 if not (0 <= minute_int <= 59):
+                     return match.group(0)
+
+                 return f"{hour} giờ {minute} phút"
+
+             # HHh format
+             else:
+                 hour = groups[0]
+                 hour_int = int(hour)
+
+                 if not (0 <= hour_int <= 23):
+                     return match.group(0)
+
+                 return f"{hour} giờ"
+
+         # Apply patterns with validation
+         text = re.sub(r'(\d{1,2}):(\d{2}):(\d{2})', validate_and_convert_time, text)
+         text = re.sub(r'(\d{1,2}):(\d{2})', validate_and_convert_time, text)
+         text = re.sub(r'(\d{1,2})h(\d{2})', validate_and_convert_time, text)
+         text = re.sub(r'(\d{1,2})h\b', validate_and_convert_time, text)
+
+         return text
+
+     def _normalize_date(self, text):
+         """Convert date notation to words with validation."""
+
+         def is_valid_date(day, month, year):
+             """Check if date components are valid."""
+             day, month, year = int(day), int(month), int(year)
+
+             # Basic range checks
+             if not (1 <= day <= 31):
+                 return False
+             if not (1 <= month <= 12):
+                 return False
+
+             return True
+
+         def date_to_text(match):
+             day, month, year = match.groups()
+             if is_valid_date(day, month, year):
+                 return f"ngày {day} tháng {month} năm {year}"
+             return match.group(0)  # Return original if invalid
+
+         def date_iso_to_text(match):
+             year, month, day = match.groups()
+             if is_valid_date(day, month, year):
209
+ return f"ngày {day} tháng {month} năm {year}"
210
+ return match.group(0)
211
+
212
+ def date_short_year(match):
213
+ day, month, year = match.groups()
214
+ full_year = f"20{year}" if int(year) < 50 else f"19{year}"
215
+ if is_valid_date(day, month, full_year):
216
+ return f"ngày {day} tháng {month} năm {full_year}"
217
+ return match.group(0)
218
+
219
+ # Apply patterns with validation
220
+ text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b',
221
+ lambda m: date_to_text(m).replace('ngày ngày', 'ngày'), text)
222
+ text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b',
223
+ lambda m: date_short_year(m).replace('ngày ngày', 'ngày'), text)
224
+ text = re.sub(r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', date_iso_to_text, text)
225
+ text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b', date_to_text, text)
226
+ text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b', date_short_year, text)
227
+
228
+ return text
229
+
230
+ def _normalize_phone(self, text):
231
+ """Convert phone numbers to digit-by-digit reading."""
232
+ def phone_to_text(match):
233
+ phone = match.group(0)
234
+ phone = re.sub(r'[^\d]', '', phone)
235
+
236
+ if phone.startswith('84') and len(phone) >= 10:
237
+ phone = '0' + phone[2:]
238
+
239
+ if 10 <= len(phone) <= 11:
240
+ words = [self.digits[int(d)] for d in phone]
241
+ return ' '.join(words) + ' '
242
+
243
+ return match.group(0)
244
+
245
+ text = re.sub(r'(\+84|84)[\s\-\.]?\d[\d\s\-\.]{7,}', phone_to_text, text)
246
+ text = re.sub(r'\b0\d[\d\s\-\.]{8,}', phone_to_text, text)
247
+ return text
248
+
+     def _normalize_numbers(self, text):
+         text = re.sub(r'(\d+(?:[,.]\d+)?)%', lambda m: f'{m.group(1)} phần trăm', text)
+         # 1. Strip thousands separators first
+         text = re.sub(r'(\d{1,3})(?:\.(\d{3}))+', lambda m: m.group(0).replace('.', ''), text)
+
+         # 2. Convert decimal numbers to words
+         def decimal_to_words(match):
+             whole = match.group(1)
+             decimal = match.group(2)
+             decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
+             separator = 'phẩy' if ',' in match.group(0) else 'chấm'
+             return f"{whole} {separator} {decimal_words}"
+
+         # 2a. Comma as decimal separator
+         text = re.sub(r'(\d+),(\d+)', decimal_to_words, text)
+         # 2b. Dot as decimal separator (1-2 decimal digits)
+         text = re.sub(r'(\d+)\.(\d{1,2})\b', decimal_to_words, text)
+
+         return text
+
+     def _read_two_digits(self, n):
+         """Read two-digit numbers in Vietnamese."""
+         if n < 10:
+             return self.digits[n]
+         elif n == 10:
+             return "mười"
+         elif n < 20:
+             if n == 15:
+                 return "mười lăm"
+             return f"mười {self.digits[n % 10]}"
+         else:
+             tens = n // 10
+             ones = n % 10
+             if ones == 0:
+                 return f"{self.digits[tens]} mươi"
+             elif ones == 1:
+                 return f"{self.digits[tens]} mươi mốt"
+             elif ones == 5:
+                 return f"{self.digits[tens]} mươi lăm"
+             else:
+                 return f"{self.digits[tens]} mươi {self.digits[ones]}"
+
+     def _read_three_digits(self, n):
+         """Read three-digit numbers in Vietnamese."""
+         if n < 100:
+             return self._read_two_digits(n)
+
+         hundreds = n // 100
+         remainder = n % 100
+         result = f"{self.digits[hundreds]} trăm"
+
+         if remainder == 0:
+             return result
+         elif remainder < 10:
+             result += f" lẻ {self.digits[remainder]}"
+         else:
+             result += f" {self._read_two_digits(remainder)}"
+
+         return result
+
+     def _convert_number_to_words(self, num):
+         """Convert a number to Vietnamese words."""
+         if num == 0:
+             return "không"
+
+         if num < 0:
+             return f"âm {self._convert_number_to_words(-num)}"
+
+         if num >= 1000000000:
+             billion = num // 1000000000
+             remainder = num % 1000000000
+             result = f"{self._read_three_digits(billion)} tỷ"
+             if remainder > 0:
+                 result += f" {self._convert_number_to_words(remainder)}"
+             return result
+
+         elif num >= 1000000:
+             million = num // 1000000
+             remainder = num % 1000000
+             result = f"{self._read_three_digits(million)} triệu"
+             if remainder > 0:
+                 result += f" {self._convert_number_to_words(remainder)}"
+             return result
+
+         elif num >= 1000:
+             thousand = num // 1000
+             remainder = num % 1000
+             result = f"{self._read_three_digits(thousand)} nghìn"
+             if remainder > 0:
+                 if remainder < 100:
+                     result += f" không trăm {self._read_two_digits(remainder)}"
+                 else:
+                     result += f" {self._read_three_digits(remainder)}"
+             return result
+
+         else:
+             return self._read_three_digits(num)
+
+     def _number_to_words(self, text):
+         """Convert all remaining numbers to words."""
+         def convert_number(match):
+             num = int(match.group(0))
+             return self._convert_number_to_words(num)
+
+         text = re.sub(r'\b\d+\b', convert_number, text)
+         return text
+
+     def _normalize_special_chars(self, text):
+         """Handle special characters."""
+         text = text.replace('&', ' và ')
+         text = text.replace('+', ' cộng ')
+         text = text.replace('=', ' bằng ')
+         text = text.replace('#', ' thăng ')
+         text = re.sub(r'[\[\]\(\)\{\}]', ' ', text)
+         text = re.sub(r'\s+[-–—]+\s+', ' ', text)
+         text = re.sub(r'\.{2,}', ' ', text)
+         text = re.sub(r'\s+\.\s+', ' ', text)
+         text = re.sub(r'[^\w\sàáảãạăắằẳẵặâấầẩẫậèéẻẽẹêếềểễệìíỉĩịòóỏõọôốồổỗộơớờởỡợùúủũụưứừửữựỳýỷỹỵđ.,!?;:@%]', ' ', text)
+         return text
+
+     def _normalize_whitespace(self, text):
+         """Normalize whitespace."""
+         text = re.sub(r'\s+', ' ', text)
+         text = text.strip()
+         return text
+
+
+ if __name__ == "__main__":
+     normalizer = VietnameseTTSNormalizer()
+
+     test_texts = [
+         "Giá 2.500.000đ (giảm 50%), mua trước 14h30 ngày 15/12/2025",
+         "Liên hệ: 0912-345-678 hoặc email@example.com",
+         "Tốc độ 120km/h, trọng lượng 75kg",
+         "Nhiệt độ 36,5°C, độ ẩm 80%",
+         "Số pi = 3,14159",
+         "Giá trị tăng 2.5M, đạt 10B",
+         "Nhiệt độ -15°C vào mùa đông",
+         "Điện áp 220V, công suất 2.5kW, tần số 50Hz",
+         "Tôi đi lấy l nước về nhà",
+         "Cần 5l nước cho công thức này",
+         "Vận tốc ánh sáng 299792km/s",
+         "Mật độ dân số 450 người/km2",
+         "Công suất 100 W/m2",
+         "Hôm nay 2025-01-15",
+         "Gọi +84 912 345 678",
+         "Nhiệt độ 25°C lúc 14:30:45",
+         "Ngày 15/12/25",
+         "Giá 3.140.159",
+     ]
+
+     print("=" * 80)
+     print("VIETNAMESE TTS NORMALIZATION TEST")
+     print("=" * 80)
+
+     for text in test_texts:
+         print(f"\n📝 Input: {text}")
+         normalized = normalizer.normalize(text)
+         print(f"🎵 Output: {normalized}")
+         print("-" * 80)
utils/phoneme_dict.json CHANGED
The diff for this file is too large to render. See raw diff
 
utils/phonemize_text.py CHANGED
@@ -1,150 +1,150 @@
+ import os
+ import json
+ import platform
+ import glob
+ from phonemizer import phonemize
+ from phonemizer.backend.espeak.espeak import EspeakWrapper
+ from utils.normalize_text import VietnameseTTSNormalizer
+
+ # Configuration
+ PHONEME_DICT_PATH = os.getenv(
+     'PHONEME_DICT_PATH',
+     os.path.join(os.path.dirname(__file__), "phoneme_dict.json")
+ )
+
+ def load_phoneme_dict(path=PHONEME_DICT_PATH):
+     """Load phoneme dictionary from JSON file."""
+     try:
+         with open(path, "r", encoding="utf-8") as f:
+             return json.load(f)
+     except FileNotFoundError:
+         raise FileNotFoundError(
+             f"Phoneme dictionary not found at {path}. "
+             "Please create it or set PHONEME_DICT_PATH environment variable."
+         )
+
+ def setup_espeak_library():
+     """Configure eSpeak library path based on operating system."""
+     system = platform.system()
+
+     if system == "Windows":
+         _setup_windows_espeak()
+     elif system == "Linux":
+         _setup_linux_espeak()
+     elif system == "Darwin":
+         _setup_macos_espeak()
+     else:
+         raise OSError(
+             f"Unsupported OS: {system}. "
+             "Only Windows, Linux, and macOS are supported."
+         )
+
+ def _setup_windows_espeak():
+     """Setup eSpeak for Windows."""
+     default_path = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
+     if os.path.exists(default_path):
+         EspeakWrapper.set_library(default_path)
+     else:
+         raise FileNotFoundError(
+             f"eSpeak library not found at {default_path}. "
+             "Please install eSpeak NG from: https://github.com/espeak-ng/espeak-ng/releases"
+         )
+
+ def _setup_linux_espeak():
+     """Setup eSpeak for Linux."""
+     search_patterns = [
+         "/usr/lib/x86_64-linux-gnu/libespeak-ng.so*",
+         "/usr/lib/x86_64-linux-gnu/libespeak.so*",
+         "/usr/lib/libespeak-ng.so*",
+         "/usr/lib64/libespeak-ng.so*",
+         "/usr/local/lib/libespeak-ng.so*",
+     ]
+
+     for pattern in search_patterns:
+         matches = glob.glob(pattern)
+         if matches:
+             EspeakWrapper.set_library(sorted(matches, key=len)[0])
+             return
+
+     raise RuntimeError(
+         "eSpeak NG library not found. Install with:\n"
+         "  Ubuntu/Debian: sudo apt-get install espeak-ng\n"
+         "  Fedora: sudo dnf install espeak-ng\n"
+         "  Arch: sudo pacman -S espeak-ng\n"
+         "See: https://github.com/pnnbao97/VieNeu-TTS/issues/5"
+     )
+
+ def _setup_macos_espeak():
+     """Setup eSpeak for macOS."""
+     espeak_lib = os.environ.get('PHONEMIZER_ESPEAK_LIBRARY')
+
+     paths_to_check = [
+         espeak_lib,
+         "/opt/homebrew/lib/libespeak-ng.dylib",  # Apple Silicon
+         "/usr/local/lib/libespeak-ng.dylib",     # Intel
+         "/opt/local/lib/libespeak-ng.dylib",     # MacPorts
+     ]
+
+     for path in paths_to_check:
+         if path and os.path.exists(path):
+             EspeakWrapper.set_library(path)
+             return
+
+     raise FileNotFoundError(
+         "eSpeak library not found. Install with:\n"
+         "  brew install espeak-ng\n"
+         "Or set: export PHONEMIZER_ESPEAK_LIBRARY=/path/to/libespeak-ng.dylib"
+     )
+
+ # Initialize
+ try:
+     setup_espeak_library()
+     phoneme_dict = load_phoneme_dict()
+     normalizer = VietnameseTTSNormalizer()
+ except Exception as e:
+     print(f"Initialization error: {e}")
+     raise
+
+ def phonemize_text(text: str) -> str:
+     """Convert text to phonemes using phonemizer."""
+     text = normalizer.normalize(text)
+     return phonemize(
+         text,
+         language="vi",
+         backend="espeak",
+         preserve_punctuation=True,
+         with_stress=True,
+         language_switch="remove-flags"
+     )
+
+ def phonemize_with_dict(text: str, phoneme_dict=phoneme_dict) -> str:
+     """Phonemize text with dictionary lookup."""
+     text = normalizer.normalize(text)
+     words = text.split()
+     result = []
+
+     for word in words:
+         if word in phoneme_dict:
+             phone_word = phoneme_dict[word]
+         else:
+             try:
+                 phone_word = phonemize(
+                     word,
+                     language='vi',
+                     backend='espeak',
+                     preserve_punctuation=True,
+                     with_stress=True,
+                     language_switch='remove-flags'
+                 )
+
+                 if word.lower().startswith('r'):
+                     phone_word = 'ɹ' + phone_word[1:]
+
+                 phoneme_dict[word] = phone_word
+             except Exception as e:
+                 print(f"Warning: Could not phonemize '{word}': {e}")
+                 phone_word = word
+
+         result.append(phone_word)
+
+     return ' '.join(result)
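A short sketch of how the dictionary-backed entry point is meant to be called (assumes espeak-ng is installed as per the setup helpers above; the exact IPA output depends on the espeak build and the shipped phoneme_dict.json):

from utils.phonemize_text import phonemize_with_dict

# Normalizes first ("75kg" becomes words), then phonemizes word by word,
# preferring phoneme_dict.json hits and caching espeak results for misses.
print(phonemize_with_dict("Trọng lượng 75kg"))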
vieneu_tts/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .vieneu_tts import VieNeuTTS, FastVieNeuTTS
+
+ __all__ = ["VieNeuTTS", "FastVieNeuTTS"]
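Because __init__.py re-exports both classes, callers can import them from the package root; a minimal sketch:

from vieneu_tts import VieNeuTTS, FastVieNeuTTS  # both defined in vieneu_tts/vieneu_tts.py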
vieneu_tts/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (241 Bytes)
vieneu_tts/__pycache__/vieneu_tts.cpython-312.pyc ADDED
Binary file (39 kB)
vieneu_tts/__pycache__/vieneu_tts_gpu.cpython-312.pyc ADDED
Binary file (24.1 kB)
vieneu_tts/vieneu_tts.py ADDED
@@ -0,0 +1,869 @@
+ from pathlib import Path
+ from typing import Generator
+ import librosa
+ import numpy as np
+ import torch
+ from neucodec import NeuCodec, DistillNeuCodec
+ from utils.phonemize_text import phonemize_with_dict
+ from collections import defaultdict
+ from concurrent.futures import ThreadPoolExecutor
+ import re
+ import gc
+
+ # ============================================================================
+ # Shared Utilities
+ # ============================================================================
+
+ def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
+     """Linear overlap-add for smooth audio concatenation"""
+     assert len(frames)
+     dtype = frames[0].dtype
+     shape = frames[0].shape[:-1]
+
+     total_size = 0
+     for i, frame in enumerate(frames):
+         frame_end = stride * i + frame.shape[-1]
+         total_size = max(total_size, frame_end)
+
+     sum_weight = np.zeros(total_size, dtype=dtype)
+     # Output buffer keeps the leading frame dims plus the summed time axis
+     # (the previous np.zeros(*shape, total_size, ...) form only worked for 1-D frames).
+     out = np.zeros(shape + (total_size,), dtype=dtype)
+
+     offset: int = 0
+     for frame in frames:
+         frame_length = frame.shape[-1]
+         t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+         weight = np.abs(0.5 - (t - 0.5))
+
+         out[..., offset : offset + frame_length] += weight * frame
+         sum_weight[offset : offset + frame_length] += weight
+         offset += stride
+     assert sum_weight.min() > 0
+     return out / sum_weight
+
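A tiny self-contained check of what _linear_overlap_add does (illustrative only; the exact blend values depend on the ramp weights above): two constant chunks advanced by half their length are merged into one buffer, and every overlapped sample becomes a weight-normalized blend of the chunks covering it.

import numpy as np

a = np.ones(8, dtype=np.float32)       # first 8-sample chunk
b = np.full(8, 2.0, dtype=np.float32)  # second chunk, offset by stride=4
mixed = _linear_overlap_add([a, b], stride=4)
print(mixed.shape)  # (12,)
print(mixed[:4])    # only chunk a covers these samples -> exactly 1.0
print(mixed[-4:])   # only chunk b -> exactly 2.0; the middle 4 blend between the two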
+
+ def _compile_codec_with_triton(codec):
+     """Compile codec with Triton for faster decoding (Windows/Linux compatible)"""
+     try:
+         import triton
+
+         if hasattr(codec, 'dec') and hasattr(codec.dec, 'resblocks'):
+             if len(codec.dec.resblocks) > 2:
+                 codec.dec.resblocks[2].forward = torch.compile(
+                     codec.dec.resblocks[2].forward,
+                     mode="reduce-overhead",
+                     dynamic=True
+                 )
+                 print("   ✅ Triton compilation enabled for codec")
+                 return True
+
+     except ImportError:
+         print("   ⚠️ Triton not found. Install for faster speed:")
+         print("      • Linux: pip install triton")
+         print("      • Windows: pip install triton-windows")
+         print("      (Optional but recommended)")
+     return False
+
+
+ # ============================================================================
+ # VieNeuTTS - Standard implementation (CPU/GPU compatible)
+ # Supports: PyTorch Transformers, GGUF/GGML quantized models
+ # ============================================================================
+
+ class VieNeuTTS:
+     """
+     Standard VieNeu-TTS implementation.
+
+     Supports:
+     - PyTorch + Transformers backend (CPU/GPU)
+     - GGUF quantized models via llama-cpp-python (CPU optimized)
+
+     Use this for:
+     - CPU-only environments
+     - Standard PyTorch workflows
+     - GGUF quantized models
+     """
+
+     def __init__(
+         self,
+         backbone_repo="pnnbao-ump/VieNeu-TTS",
+         backbone_device="cpu",
+         codec_repo="neuphonic/neucodec",
+         codec_device="cpu",
+     ):
+         """
+         Initialize VieNeu-TTS.
+
+         Args:
+             backbone_repo: Model repository or path to GGUF file
+             backbone_device: Device for backbone ('cpu', 'cuda', 'gpu')
+             codec_repo: Codec repository
+             codec_device: Device for codec
+         """
+
+         # Constants
+         self.sample_rate = 24_000
+         self.max_context = 2048
+         self.hop_length = 480
+         self.streaming_overlap_frames = 1
+         self.streaming_frames_per_chunk = 25
+         self.streaming_lookforward = 5
+         self.streaming_lookback = 50
+         self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
+
+         # Flags
+         self._is_quantized_model = False
+         self._is_onnx_codec = False
+
+         # HF tokenizer
+         self.tokenizer = None
+
+         # Load models
+         self._load_backbone(backbone_repo, backbone_device)
+         self._load_codec(codec_repo, codec_device)
+
+     def _load_backbone(self, backbone_repo, backbone_device):
+         print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
+
+         if "gguf" in backbone_repo.lower():
+             try:
+                 from llama_cpp import Llama
+             except ImportError as e:
+                 raise ImportError(
+                     "Failed to import `llama_cpp`. "
+                     "Please install it with:\n"
+                     "  pip install llama-cpp-python"
+                 ) from e
+             self.backbone = Llama.from_pretrained(
+                 repo_id=backbone_repo,
+                 filename="*.gguf",
+                 verbose=False,
+                 n_gpu_layers=-1 if backbone_device == "gpu" else 0,
+                 n_ctx=self.max_context,
+                 mlock=True,
+                 flash_attn=True if backbone_device == "gpu" else False,
+             )
+             self._is_quantized_model = True
+
+         else:
+             from transformers import AutoTokenizer, AutoModelForCausalLM
+             self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
+             self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
+                 torch.device(backbone_device)
+             )
+
+     def _load_codec(self, codec_repo, codec_device):
+         print(f"Loading codec from: {codec_repo} on {codec_device} ...")
+         match codec_repo:
+             case "neuphonic/neucodec":
+                 self.codec = NeuCodec.from_pretrained(codec_repo)
+                 self.codec.eval().to(codec_device)
+             case "neuphonic/distill-neucodec":
+                 self.codec = DistillNeuCodec.from_pretrained(codec_repo)
+                 self.codec.eval().to(codec_device)
+             case "neuphonic/neucodec-onnx-decoder":
+                 if codec_device != "cpu":
+                     raise ValueError("Onnx decoder only currently runs on CPU.")
+                 try:
+                     from neucodec import NeuCodecOnnxDecoder
+                 except ImportError as e:
+                     raise ImportError(
+                         "Failed to import the onnx decoder. "
+                         "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
+                     ) from e
+                 self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
+                 self._is_onnx_codec = True
+             case _:
+                 raise ValueError(f"Unsupported codec repository: {codec_repo}")
+
+     def encode_reference(self, ref_audio_path: str | Path):
+         """Encode reference audio to codes"""
+         wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
+         wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
+         with torch.no_grad():
+             ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
+         return ref_codes
+
+     def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
+         """
+         Perform inference to generate speech from text using the TTS model and reference audio.
+
+         Args:
+             text (str): Input text to be converted to speech.
+             ref_codes (np.ndarray | torch.Tensor): Encoded reference.
+             ref_text (str): Reference text for reference audio.
+         Returns:
+             np.ndarray: Generated speech waveform.
+         """
+
+         # Generate tokens
+         if self._is_quantized_model:
+             output_str = self._infer_ggml(ref_codes, ref_text, text)
+         else:
+             prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
+             output_str = self._infer_torch(prompt_ids)
+
+         # Decode
+         wav = self._decode(output_str)
+
+         return wav
+
+     def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+         """
+         Perform streaming inference to generate speech from text using the TTS model and reference audio.
+
+         Args:
+             text (str): Input text to be converted to speech.
+             ref_codes (np.ndarray | torch.Tensor): Encoded reference.
+             ref_text (str): Reference text for reference audio.
+         Yields:
+             np.ndarray: Generated speech waveform.
+         """
+
+         if self._is_quantized_model:
+             return self._infer_stream_ggml(ref_codes, ref_text, text)
+         else:
+             raise NotImplementedError("Streaming is not implemented for the torch backend!")
+
+     def _decode(self, codes: str):
+         """Decode speech tokens to audio waveform."""
+         # Extract speech token IDs using regex
+         speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
+
+         if len(speech_ids) == 0:
+             raise ValueError(
+                 "No valid speech tokens found in the output. "
+                 "The model may not have generated proper speech tokens."
+             )
+
+         # Onnx decode
+         if self._is_onnx_codec:
+             codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
+             recon = self.codec.decode_code(codes)
+         # Torch decode
+         else:
+             with torch.no_grad():
+                 codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
+                     self.codec.device
+                 )
+                 recon = self.codec.decode_code(codes).cpu().numpy()
+
+         return recon[0, 0, :]
+
+     def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
+         input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
+
+         speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
+         speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
+         text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
+         text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
+         text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
+
+         input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
+         chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
+         ids = self.tokenizer.encode(chat)
+
+         text_replace_idx = ids.index(text_replace)
+         ids = (
+             ids[:text_replace_idx]
+             + [text_prompt_start]
+             + input_ids
+             + [text_prompt_end]
+             + ids[text_replace_idx + 1 :]  # noqa
+         )
+
+         speech_replace_idx = ids.index(speech_replace)
+         codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
+         codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
+         ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
+
+         return ids
+
+     def _infer_torch(self, prompt_ids: list[int]) -> str:
+         prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
+         speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
+         with torch.no_grad():
+             output_tokens = self.backbone.generate(
+                 prompt_tensor,
+                 max_length=self.max_context,
+                 eos_token_id=speech_end_id,
+                 do_sample=True,
+                 temperature=1.0,
+                 top_k=50,
+                 use_cache=True,
+                 min_new_tokens=50,
+             )
+         input_length = prompt_tensor.shape[-1]
+         output_str = self.tokenizer.decode(
+             output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
+         )
+         return output_str
+
+     def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
+         ref_text = phonemize_with_dict(ref_text)
+         input_text = phonemize_with_dict(input_text)
+
+         codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+         prompt = (
+             f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+             f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+         )
+         output = self.backbone(
+             prompt,
+             max_tokens=self.max_context,
+             temperature=1.0,
+             top_k=50,
+             stop=["<|SPEECH_GENERATION_END|>"],
+         )
+         output_str = output["choices"][0]["text"]
+         return output_str
+
+     def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
+         ref_text = phonemize_with_dict(ref_text)
+         input_text = phonemize_with_dict(input_text)
+
+         codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+         prompt = (
+             f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+             f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+         )
+
+         audio_cache: list[np.ndarray] = []
+         token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
+         n_decoded_samples: int = 0
+         n_decoded_tokens: int = len(ref_codes)
+
+         for item in self.backbone(
+             prompt,
+             max_tokens=self.max_context,
+             temperature=1.0,
+             top_k=50,
+             stop=["<|SPEECH_GENERATION_END|>"],
+             stream=True
+         ):
+             output_str = item["choices"][0]["text"]
+             token_cache.append(output_str)
+
+             if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
+
+                 # decode chunk
+                 tokens_start = max(
+                     n_decoded_tokens
+                     - self.streaming_lookback
+                     - self.streaming_overlap_frames,
+                     0
+                 )
+                 tokens_end = (
+                     n_decoded_tokens
+                     + self.streaming_frames_per_chunk
+                     + self.streaming_lookforward
+                     + self.streaming_overlap_frames
+                 )
+                 sample_start = (
+                     n_decoded_tokens - tokens_start
+                 ) * self.hop_length
+                 sample_end = (
+                     sample_start
+                     + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
+                 )
+                 curr_codes = token_cache[tokens_start:tokens_end]
+                 recon = self._decode("".join(curr_codes))
+                 recon = recon[sample_start:sample_end]
+                 audio_cache.append(recon)
+
+                 # postprocess
+                 processed_recon = _linear_overlap_add(
+                     audio_cache, stride=self.streaming_stride_samples
+                 )
+                 new_samples_end = len(audio_cache) * self.streaming_stride_samples
+                 processed_recon = processed_recon[
+                     n_decoded_samples:new_samples_end
+                 ]
+                 n_decoded_samples = new_samples_end
+                 n_decoded_tokens += self.streaming_frames_per_chunk
+                 yield processed_recon
+
+         # final decoding handled separately as non-constant chunk size
+         remaining_tokens = len(token_cache) - n_decoded_tokens
+         if len(token_cache) > n_decoded_tokens:
+             tokens_start = max(
+                 len(token_cache)
+                 - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
+                 0
+             )
+             sample_start = (
+                 len(token_cache)
+                 - tokens_start
+                 - remaining_tokens
+                 - self.streaming_overlap_frames
+             ) * self.hop_length
+             curr_codes = token_cache[tokens_start:]
+             recon = self._decode("".join(curr_codes))
+             recon = recon[sample_start:]
+             audio_cache.append(recon)
+
+             processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
+             processed_recon = processed_recon[n_decoded_samples:]
+             yield processed_recon
+
+
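Pulling the pieces together, a hedged end-to-end sketch for the standard class (CPU settings shown; "path/to/reference.wav" and its transcript are placeholders for any short reference clip):

import soundfile as sf
from vieneu_tts import VieNeuTTS

tts = VieNeuTTS(backbone_device="cpu", codec_device="cpu")
ref_codes = tts.encode_reference("path/to/reference.wav")     # placeholder clip
wav = tts.infer("Xin chào các bạn!", ref_codes,
                ref_text="transcript of the reference clip")  # placeholder text
sf.write("output.wav", wav, tts.sample_rate)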
+ # ============================================================================
+ # FastVieNeuTTS - GPU-optimized implementation
+ # Requires: LMDeploy with CUDA
+ # ============================================================================
+
+ class FastVieNeuTTS:
+     """
+     GPU-optimized VieNeu-TTS using LMDeploy TurbomindEngine.
+     """
+
+     def __init__(
+         self,
+         backbone_repo="pnnbao-ump/VieNeu-TTS",
+         backbone_device="cuda",
+         codec_repo="neuphonic/neucodec",
+         codec_device="cuda",
+         memory_util=0.3,
+         tp=1,
+         enable_prefix_caching=True,
+         quant_policy=8,
+         enable_triton=True,
+         max_batch_size=8,
+     ):
+         """
+         Initialize FastVieNeuTTS with LMDeploy backend and optimizations.
+
+         Args:
+             backbone_repo: Model repository
+             backbone_device: Device for backbone (must be CUDA)
+             codec_repo: Codec repository
+             codec_device: Device for codec
+             memory_util: GPU memory utilization (0.0-1.0)
+             tp: Tensor parallel size for multi-GPU
+             enable_prefix_caching: Enable prefix caching for faster batch processing
+             quant_policy: KV cache quantization (0=off, 8=int8, 4=int4)
+             enable_triton: Enable Triton compilation for codec
+             max_batch_size: Maximum batch size for inference (prevent GPU overload)
+         """
+
+         if backbone_device != "cuda" and not backbone_device.startswith("cuda:"):
+             raise ValueError("LMDeploy backend requires CUDA device")
+
+         # Constants
+         self.sample_rate = 24_000
+         self.max_context = 2048
+         self.hop_length = 480
+         self.streaming_overlap_frames = 1
+         self.streaming_frames_per_chunk = 50
+         self.streaming_lookforward = 5
+         self.streaming_lookback = 50
+         self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
+
+         self.max_batch_size = max_batch_size
+
+         self._ref_cache = {}
+
+         self.stored_dict = defaultdict(dict)
+
+         # Flags
+         self._is_onnx_codec = False
+         self._triton_enabled = False
+
+         # Load models
+         self._load_backbone_lmdeploy(backbone_repo, memory_util, tp, enable_prefix_caching, quant_policy)
+         self._load_codec(codec_repo, codec_device, enable_triton)
+
+         self._warmup_model()
+
+         print("✅ FastVieNeuTTS with optimizations loaded successfully!")
+         print(f"   Max batch size: {self.max_batch_size} (adjustable to prevent GPU overload)")
+
+     def _load_backbone_lmdeploy(self, repo, memory_util, tp, enable_prefix_caching, quant_policy):
+         """Load backbone using LMDeploy's TurbomindEngine"""
+         print(f"Loading backbone with LMDeploy from: {repo}")
+
+         try:
+             from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
+         except ImportError as e:
+             raise ImportError(
+                 "Failed to import `lmdeploy`. "
+                 "Please install it with: pip install lmdeploy"
+             ) from e
+
+         backend_config = TurbomindEngineConfig(
+             cache_max_entry_count=memory_util,
+             tp=tp,
+             enable_prefix_caching=enable_prefix_caching,
+             dtype='bfloat16',
+             quant_policy=quant_policy
+         )
+
+         self.backbone = pipeline(repo, backend_config=backend_config)
+
+         self.gen_config = GenerationConfig(
+             top_p=0.95,
+             top_k=50,
+             temperature=1.0,
+             max_new_tokens=1024,
+             repetition_penalty=1.0,
+             do_sample=True,
+             min_new_tokens=40,
+             min_p=0.1,
+         )
+
+         print("   LMDeploy TurbomindEngine initialized")
+         print(f"   - Memory util: {memory_util}")
+         print(f"   - Tensor Parallel: {tp}")
+         print(f"   - Prefix caching: {enable_prefix_caching}")
+         print(f"   - KV quant: {quant_policy} ({'Enabled' if quant_policy > 0 else 'Disabled'})")
+
+     def _load_codec(self, codec_repo, codec_device, enable_triton):
+         """Load codec with optional Triton compilation"""
+         print(f"Loading codec from: {codec_repo} on {codec_device}")
+
+         match codec_repo:
+             case "neuphonic/neucodec":
+                 self.codec = NeuCodec.from_pretrained(codec_repo)
+                 self.codec.eval().to(codec_device)
+             case "neuphonic/distill-neucodec":
+                 self.codec = DistillNeuCodec.from_pretrained(codec_repo)
+                 self.codec.eval().to(codec_device)
+             case "neuphonic/neucodec-onnx-decoder":
+                 if codec_device != "cpu":
+                     raise ValueError("ONNX decoder only runs on CPU")
+                 try:
+                     from neucodec import NeuCodecOnnxDecoder
+                 except ImportError as e:
+                     raise ImportError(
+                         "Failed to import ONNX decoder. "
+                         "Ensure onnxruntime and neucodec >= 0.0.4 are installed."
+                     ) from e
+                 self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
+                 self._is_onnx_codec = True
+             case _:
+                 raise ValueError(f"Unsupported codec repository: {codec_repo}")
+
+         if enable_triton and not self._is_onnx_codec and codec_device != "cpu":
+             self._triton_enabled = _compile_codec_with_triton(self.codec)
+
+     def _warmup_model(self):
+         """Warmup inference pipeline to reduce first-token latency"""
+         print("🔥 Warming up model...")
+         try:
+             dummy_codes = list(range(10))
+             dummy_prompt = self._format_prompt(dummy_codes, "warmup", "test")
+             _ = self.backbone([dummy_prompt], gen_config=self.gen_config, do_preprocess=False)
+             print("   ✅ Warmup complete")
+         except Exception as e:
+             print(f"   ⚠️ Warmup failed (non-critical): {e}")
+
+     def encode_reference(self, ref_audio_path: str | Path):
+         """Encode reference audio to codes"""
+         wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
+         wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
+         with torch.no_grad():
+             ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
+         return ref_codes
+
+     def get_cached_reference(self, voice_name: str, audio_path: str, ref_text: str = None):
+         """
+         Get or create cached reference codes.
+
+         Args:
+             voice_name: Unique identifier for this voice
+             audio_path: Path to reference audio
+             ref_text: Optional reference text (stored with codes)
+
+         Returns:
+             ref_codes: Encoded reference codes
+         """
+         cache_key = f"{voice_name}_{audio_path}"
+
+         if cache_key not in self._ref_cache:
+             ref_codes = self.encode_reference(audio_path)
+             self._ref_cache[cache_key] = {
+                 'codes': ref_codes,
+                 'ref_text': ref_text
+             }
+
+         return self._ref_cache[cache_key]['codes']
+
+     def add_speaker(self, user_id: int, audio_file: str, ref_text: str):
+         """
+         Add a speaker to the stored dictionary for easy access.
+
+         Args:
+             user_id: Unique user ID
+             audio_file: Reference audio file path
+             ref_text: Reference text
+
+         Returns:
+             user_id: The user ID for use in streaming
+         """
+         codes = self.encode_reference(audio_file)
+
+         if isinstance(codes, torch.Tensor):
+             codes = codes.cpu().numpy()
+         if isinstance(codes, np.ndarray):
+             codes = codes.flatten().tolist()
+
+         self.stored_dict[f"{user_id}"]['codes'] = codes
+         self.stored_dict[f"{user_id}"]['ref_text'] = ref_text
+
+         return user_id
+
+     def _decode(self, codes: str):
+         """Decode speech tokens to audio waveform"""
+         speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
+
+         if len(speech_ids) == 0:
+             raise ValueError("No valid speech tokens found in output")
+
+         if self._is_onnx_codec:
+             codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
+             recon = self.codec.decode_code(codes)
+         else:
+             with torch.no_grad():
+                 codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
+                     self.codec.device
+                 )
+                 recon = self.codec.decode_code(codes).cpu().numpy()
+
+         return recon[0, 0, :]
+
+     def _decode_batch(self, codes_list: list[str], max_workers: int = None):
+         """
+         Decode multiple code strings in parallel.
+
+         Args:
+             codes_list: List of code strings to decode
+             max_workers: Number of parallel workers (auto-tuned if None)
+
+         Returns:
+             List of decoded audio arrays
+         """
+         # Auto-tune workers based on GPU memory and batch size
+         if max_workers is None:
+             if torch.cuda.is_available():
+                 gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+                 # 1 worker per 4GB VRAM, max 4 workers
+                 max_workers = min(max(1, int(gpu_mem_gb / 4)), 4)
+             else:
+                 max_workers = 2
+
+         # For small batches, use sequential to avoid overhead
+         if len(codes_list) <= 2:
+             return [self._decode(codes) for codes in codes_list]
+
+         # Parallel decoding with controlled workers
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             futures = [executor.submit(self._decode, codes) for codes in codes_list]
+             results = [f.result() for f in futures]
+         return results
+
+     def _format_prompt(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
+         """Format prompt for LMDeploy"""
+         ref_text_phones = phonemize_with_dict(ref_text)
+         input_text_phones = phonemize_with_dict(input_text)
+
+         codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+
+         prompt = (
+             f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text_phones} {input_text_phones}"
+             f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+         )
+
+         return prompt
+
+     def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
+         """
+         Single inference.
+
+         Args:
+             text: Input text to synthesize
+             ref_codes: Encoded reference audio codes
+             ref_text: Reference text for reference audio
+
+         Returns:
+             Generated speech waveform as numpy array
+         """
+         if isinstance(ref_codes, torch.Tensor):
+             ref_codes = ref_codes.cpu().numpy()
+         if isinstance(ref_codes, np.ndarray):
+             ref_codes = ref_codes.flatten().tolist()
+
+         prompt = self._format_prompt(ref_codes, ref_text, text)
+
+         # Use LMDeploy pipeline for generation
+         responses = self.backbone([prompt], gen_config=self.gen_config, do_preprocess=False)
+         output_str = responses[0].text
+
+         # Decode to audio
+         wav = self._decode(output_str)
+
+         return wav
+
+     def infer_batch(self, texts: list[str], ref_codes: np.ndarray | torch.Tensor, ref_text: str, max_batch_size: int = None) -> list[np.ndarray]:
+         """
+         Batch inference for multiple texts.
+
+         Args:
+             texts: List of input texts to synthesize
+             ref_codes: Encoded reference audio codes
+             ref_text: Reference text for reference audio
+             max_batch_size: Maximum chunks to process at once (prevent GPU overload)
+
+         Returns:
+             List of generated speech waveforms
+         """
+         if max_batch_size is None:
+             max_batch_size = self.max_batch_size
+
+         if not isinstance(texts, list):
+             texts = [texts]
+
+         if isinstance(ref_codes, torch.Tensor):
+             ref_codes = ref_codes.cpu().numpy()
+         if isinstance(ref_codes, np.ndarray):
+             ref_codes = ref_codes.flatten().tolist()
+
+         all_wavs = []
+
+         # Process in smaller batches to avoid GPU OOM
+         for i in range(0, len(texts), max_batch_size):
+             batch_texts = texts[i:i+max_batch_size]
+
+             # Format prompts for this batch
+             prompts = [self._format_prompt(ref_codes, ref_text, text) for text in batch_texts]
+
+             # Batch generation with LMDeploy
+             responses = self.backbone(prompts, gen_config=self.gen_config, do_preprocess=False)
+
+             # Decode outputs (with smart parallelization)
+             batch_codes = [response.text for response in responses]
+
+             # Auto-tune parallel workers based on batch size
+             if len(batch_codes) > 3:
+                 batch_wavs = self._decode_batch(batch_codes)
+             else:
+                 # Sequential for small batches (less overhead)
+                 batch_wavs = [self._decode(codes) for codes in batch_codes]
+
+             all_wavs.extend(batch_wavs)
+
+             # Clean up memory between batches
+             if i + max_batch_size < len(texts):
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+
+         return all_wavs
+
+     def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+         """
+         Streaming inference with low latency.
+
+         Args:
+             text: Input text to synthesize
+             ref_codes: Encoded reference audio codes
+             ref_text: Reference text for reference audio
+
+         Yields:
+             Audio chunks as numpy arrays
+         """
+         if isinstance(ref_codes, torch.Tensor):
+             ref_codes = ref_codes.cpu().numpy()
+         if isinstance(ref_codes, np.ndarray):
+             ref_codes = ref_codes.flatten().tolist()
+
+         prompt = self._format_prompt(ref_codes, ref_text, text)
+
+         audio_cache = []
+         token_cache = [f"<|speech_{idx}|>" for idx in ref_codes]
+         n_decoded_samples = 0
+         n_decoded_tokens = len(ref_codes)
+
+         for response in self.backbone.stream_infer([prompt], gen_config=self.gen_config, do_preprocess=False):
+             output_str = response.text
+
+             # Extract new tokens
+             new_tokens = output_str[len("".join(token_cache[len(ref_codes):])):] if len(token_cache) > len(ref_codes) else output_str
+
+             if new_tokens:
+                 token_cache.append(new_tokens)
+
+             # Check if we have enough tokens to decode a chunk
+             if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
+
+                 # Decode chunk with context
+                 tokens_start = max(
+                     n_decoded_tokens - self.streaming_lookback - self.streaming_overlap_frames,
+                     0
+                 )
+                 tokens_end = (
+                     n_decoded_tokens
+                     + self.streaming_frames_per_chunk
+                     + self.streaming_lookforward
+                     + self.streaming_overlap_frames
+                 )
+                 sample_start = (n_decoded_tokens - tokens_start) * self.hop_length
+                 sample_end = (
+                     sample_start
+                     + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
+                 )
+
+                 curr_codes = token_cache[tokens_start:tokens_end]
+                 recon = self._decode("".join(curr_codes))
+                 recon = recon[sample_start:sample_end]
+                 audio_cache.append(recon)
+
+                 # Overlap-add processing
+                 processed_recon = _linear_overlap_add(
+                     audio_cache, stride=self.streaming_stride_samples
+                 )
+                 new_samples_end = len(audio_cache) * self.streaming_stride_samples
+                 processed_recon = processed_recon[n_decoded_samples:new_samples_end]
+                 n_decoded_samples = new_samples_end
+                 n_decoded_tokens += self.streaming_frames_per_chunk
+
+                 yield processed_recon
+
+         # Final chunk
+         remaining_tokens = len(token_cache) - n_decoded_tokens
+         if remaining_tokens > 0:
+             tokens_start = max(
+                 len(token_cache) - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
+                 0
+             )
+             sample_start = (
+                 len(token_cache) - tokens_start - remaining_tokens - self.streaming_overlap_frames
+             ) * self.hop_length
+
+             curr_codes = token_cache[tokens_start:]
+             recon = self._decode("".join(curr_codes))
+             recon = recon[sample_start:]
+             audio_cache.append(recon)
+
+             processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
+             processed_recon = processed_recon[n_decoded_samples:]
+             yield processed_recon
+
+     def cleanup_memory(self):
+         """Clean up GPU memory"""
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         gc.collect()
+         print("🧹 Memory cleaned up")
+
+     def get_optimization_stats(self) -> dict:
+         """
+         Get current optimization statistics.
+
+         Returns:
+             Dictionary with optimization info
+         """
+         return {
+             'triton_enabled': self._triton_enabled,
+             'cached_references': len(self._ref_cache),
+             'active_sessions': len(self.stored_dict),
+             'kv_quant': self.gen_config.__dict__.get('quant_policy', 0),
+             'prefix_caching': True,  # Always enabled in our config
+         }
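Finally, a hedged sketch of the GPU path (assumes a CUDA machine with lmdeploy installed; the reference paths are placeholders, and chunk sizing follows the streaming constants defined above):

import numpy as np
import soundfile as sf
from vieneu_tts import FastVieNeuTTS

tts = FastVieNeuTTS(memory_util=0.3, max_batch_size=8)
ref_codes = tts.encode_reference("path/to/reference.wav")  # placeholder clip
ref_text = "transcript of the reference clip"              # placeholder text

# Batch synthesis: one waveform per input text, chunked by max_batch_size.
wavs = tts.infer_batch(["Câu thứ nhất.", "Câu thứ hai."], ref_codes, ref_text)

# Streaming: yields ~50-frame chunks (50 * 480 samples at 24 kHz) as decoded.
chunks = list(tts.infer_stream("Xin chào!", ref_codes, ref_text))
sf.write("stream.wav", np.concatenate(chunks), tts.sample_rate)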