pnnbao-ump commited on
Commit
d949d26
·
verified ·
1 Parent(s): 233afa6

Upload 16 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim-bullseye

ENV PIP_NO_CACHE_DIR=1
ENV DEBIAN_FRONTEND=noninteractive

# System packages must be installed while still root: the original file
# switched to the non-root user first, which makes `apt-get` fail with a
# permission error and breaks the build.
COPY packages.txt ./
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        $(xargs < packages.txt) \
        git \
        libgl1 \
        libsm6 \
        libxext6 \
    && apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Non-root user (kept from the default HF Spaces template); drop privileges
# only after the root-only system setup above is done.
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install Python dependencies first so this layer is cached while the app
# source changes.
COPY --chown=user requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the remaining sources (app.py, config.yaml, utils/, etc.).
COPY --chown=user . /app

# Launch the Gradio app; Gradio serves on port 7860 by default.
CMD ["python", "app.py"]
config.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Text chunking limits consumed by the Gradio app.
text_settings:
  max_chars_per_chunk: 256          # max characters per synthesis chunk
  max_total_chars_streaming: 3000   # hard cap on input length in streaming mode

# Selectable TTS backbones; keys double as the UI dropdown labels.
backbone_configs:
  "VieNeu-TTS (GPU)":
    repo: pnnbao-ump/VieNeu-TTS
    supports_streaming: false
    description: Chất lượng cao nhất, yêu cầu GPU
  "VieNeu-TTS-q8-gguf":
    repo: pnnbao-ump/VieNeu-TTS-q8-gguf
    supports_streaming: true
    description: Cân bằng giữa chất lượng và tốc độ
  "VieNeu-TTS-q4-gguf":
    repo: pnnbao-ump/VieNeu-TTS-q4-gguf
    supports_streaming: true
    description: Nhẹ nhất, phù hợp CPU

# Audio codecs; use_preencoded=true means the codec consumes pre-encoded
# .pt code files instead of encoding the reference wav at runtime.
codec_configs:
  "NeuCodec (Standard)":
    repo: neuphonic/neucodec
    description: Codec chuẩn, tốc độ trung bình
    use_preencoded: false
  "NeuCodec ONNX (Fast CPU)":
    repo: neuphonic/neucodec-onnx-decoder
    description: Tối ưu cho CPU, cần pre-encoded codes
    use_preencoded: true

# Preset voices: each entry bundles a reference wav, its transcript, and
# pre-encoded codec codes for the ONNX decoder path.
voice_samples:
  "Tuyên (nam miền Bắc)":
    audio: ./sample/Tuyên (nam miền Bắc).wav
    text: ./sample/Tuyên (nam miền Bắc).txt
    codes: ./sample/Tuyên (nam miền Bắc).pt
  "Vĩnh (nam miền Nam)":
    audio: ./sample/Vĩnh (nam miền Nam).wav
    text: ./sample/Vĩnh (nam miền Nam).txt
    codes: ./sample/Vĩnh (nam miền Nam).pt
  "Bình (nam miền Bắc)":
    audio: ./sample/Bình (nam miền Bắc).wav
    text: ./sample/Bình (nam miền Bắc).txt
    codes: ./sample/Bình (nam miền Bắc).pt
  "Nguyên (nam miền Nam)":
    audio: ./sample/Nguyên (nam miền Nam).wav
    text: ./sample/Nguyên (nam miền Nam).txt
    codes: ./sample/Nguyên (nam miền Nam).pt
  "Sơn (nam miền Nam)":
    audio: ./sample/Sơn (nam miền Nam).wav
    text: ./sample/Sơn (nam miền Nam).txt
    codes: ./sample/Sơn (nam miền Nam).pt
  "Đoan (nữ miền Nam)":
    audio: ./sample/Đoan (nữ miền Nam).wav
    text: ./sample/Đoan (nữ miền Nam).txt
    codes: ./sample/Đoan (nữ miền Nam).pt
  "Ngọc (nữ miền Bắc)":
    audio: ./sample/Ngọc (nữ miền Bắc).wav
    text: ./sample/Ngọc (nữ miền Bắc).txt
    codes: ./sample/Ngọc (nữ miền Bắc).pt
  "Ly (nữ miền Bắc)":
    audio: ./sample/Ly (nữ miền Bắc).wav
    text: ./sample/Ly (nữ miền Bắc).txt
    codes: ./sample/Ly (nữ miền Bắc).pt
  "Dung (nữ miền Nam)":
    audio: ./sample/Dung (nữ miền Nam).wav
    text: ./sample/Dung (nữ miền Nam).txt
    codes: ./sample/Dung (nữ miền Nam).pt
gradio_app.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import soundfile as sf
import tempfile
import torch
from vieneu_tts import VieNeuTTS
import os
import time
import numpy as np
import re
from typing import Generator
import queue
import threading
import yaml
from utils.core_utils import split_text_into_chunks

print("⏳ Đang khởi động VieNeu-TTS...")

# --- CONSTANTS & CONFIG ---
# All model/codec/voice settings live in config.yaml next to this file.
CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.yaml")
try:
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        _config = yaml.safe_load(f) or {}
except Exception as e:
    raise RuntimeError(f"Không thể đọc config.yaml: {e}")

BACKBONE_CONFIGS = _config.get("backbone_configs", {})
CODEC_CONFIGS = _config.get("codec_configs", {})
VOICE_SAMPLES = _config.get("voice_samples", {})

_text_settings = _config.get("text_settings", {})
MAX_CHARS_PER_CHUNK = _text_settings.get("max_chars_per_chunk", 256)
MAX_TOTAL_CHARS_STREAMING = _text_settings.get("max_total_chars_streaming", 3000)

# Fail fast at import time on an incomplete config instead of erroring later
# in the middle of a synthesis request.
if not BACKBONE_CONFIGS or not CODEC_CONFIGS:
    raise ValueError("config.yaml thiếu backbone_configs hoặc codec_configs")
if not VOICE_SAMPLES:
    raise ValueError("config.yaml thiếu voice_samples")

# --- 1. MODEL CONFIGURATION ---
# Module-level model state shared by the Gradio callbacks below.
tts = None                # VieNeuTTS instance once load_model() succeeds
current_backbone = None   # key into BACKBONE_CONFIGS of the loaded backbone
current_codec = None      # key into CODEC_CONFIGS of the loaded codec
model_loaded = False      # guards synthesize_speech() until a model is ready
45
+
46
def load_model(backbone_choice, codec_choice, device_choice):
    """Load the selected backbone/codec pair and stream UI status updates.

    Generator used as a Gradio event handler. Each yield updates three
    components: (status markdown, "Bắt đầu" button state, "Tải Model"
    button state).
    """
    global tts, current_backbone, current_codec, model_loaded

    # BUGFIX: config keys spell "gguf" in lowercase ("VieNeu-TTS-q8-gguf"),
    # so the original case-sensitive check `"GGUF" in backbone_choice` never
    # matched and GGUF backbones were handed the PyTorch-style "cuda" device
    # name. Compare case-insensitively instead.
    is_gguf = "gguf" in backbone_choice.lower()

    # Disable both buttons immediately so the user cannot double-trigger.
    yield (
        "⏳ Đang tải model, vui lòng đợi...",
        gr.update(interactive=False),  # disable "Bắt đầu"
        gr.update(interactive=False)   # disable "Tải Model"
    )

    try:
        backbone_config = BACKBONE_CONFIGS[backbone_choice]
        codec_config = CODEC_CONFIGS[codec_choice]

        # Resolve device strings: llama.cpp (GGUF) expects "gpu"/"cpu",
        # PyTorch expects "cuda"/"cpu"; the ONNX codec only runs on CPU.
        if device_choice == "Auto":
            if is_gguf:
                backbone_device = "gpu" if torch.cuda.is_available() else "cpu"
            else:
                backbone_device = "cuda" if torch.cuda.is_available() else "cpu"

            if "ONNX" in codec_choice:
                codec_device = "cpu"
            else:
                codec_device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            backbone_device = device_choice.lower()
            codec_device = device_choice.lower()

            if "ONNX" in codec_choice:
                codec_device = "cpu"

            if is_gguf and backbone_device == "cuda":
                backbone_device = "gpu"

        print(f"📦 Đang tải model...")
        print(f"   Backbone: {backbone_config['repo']} on {backbone_device}")
        print(f"   Codec: {codec_config['repo']} on {codec_device}")

        tts = VieNeuTTS(
            backbone_repo=backbone_config["repo"],
            backbone_device=backbone_device,
            codec_repo=codec_config["repo"],
            codec_device=codec_device
        )

        current_backbone = backbone_choice
        current_codec = codec_choice
        model_loaded = True  # synthesis is now allowed

        note_for_llama_cpp = "\n⚠️ Lưu ý: Nếu bạn chọn gpu (cuda) cho bản gguf cần phải cài đặt đúng theo hướng dẫn ở link này để tận dụng được GPU: https://pypi.org/project/llama-cpp-python/"
        preencoded_note = "\n⚠️ Codec ONNX cần sử dụng pre-encoded codes (.pt files)" if codec_config['use_preencoded'] else ""

        success_msg = (
            f"✅ Model đã tải thành công!\n\n"
            f"🦜 Model Device: {backbone_device.upper()}{note_for_llama_cpp}\n\n"
            f"🎵 Codec Device: {codec_device.upper()}{preencoded_note}"
        )

        yield (
            success_msg,
            gr.update(interactive=True),  # enable "Bắt đầu"
            gr.update(interactive=True)   # enable "Tải Model"
        )

    except Exception as e:
        import traceback
        traceback.print_exc()
        model_loaded = False

        yield (
            f"❌ Lỗi khi tải model: {str(e)}",
            gr.update(interactive=False),  # keep "Bắt đầu" disabled
            gr.update(interactive=True)    # re-enable "Tải Model" for a retry
        )
122
+
123
# --- 2. DATA & HELPERS ---
# Voices available on the quantized GGUF backbones; the full backbone
# supports every entry in VOICE_SAMPLES.
GGUF_ALLOWED_VOICES = [
    "Vĩnh (nam miền Nam)",
    "Bình (nam miền Bắc)",
    "Ngọc (nữ miền Bắc)",
    "Dung (nữ miền Nam)",
]

def get_voice_options(backbone_choice: str):
    """Filter voice options: GGUF only shows the 4 allowed voices."""
    # BUGFIX: compare case-insensitively — config keys use lowercase "gguf"
    # while another call site passes "GGUF Q4", which the original lowercase
    # substring test silently failed to match.
    if "gguf" in backbone_choice.lower():
        return [v for v in GGUF_ALLOWED_VOICES if v in VOICE_SAMPLES]
    return list(VOICE_SAMPLES.keys())

def update_voice_dropdown(backbone_choice: str, current_voice: str):
    """Refresh the voice dropdown, keeping the current voice when still valid."""
    options = get_voice_options(backbone_choice)
    new_value = current_voice if current_voice in options else (options[0] if options else None)
    return gr.update(choices=options, value=new_value)
141
+
142
# --- 3. CORE LOGIC FUNCTIONS ---
def load_reference_info(voice_choice):
    """Return (audio_path, reference_text) for a preset voice.

    Falls back to (None, "") for unknown voices and reports file problems
    in the second element of the tuple.
    """
    if voice_choice not in VOICE_SAMPLES:
        return None, ""

    sample = VOICE_SAMPLES[voice_choice]
    audio_path = sample["audio"]
    text_path = sample["text"]
    try:
        if not os.path.exists(text_path):
            return audio_path, "⚠️ Không tìm thấy file text mẫu."
        with open(text_path, "r", encoding="utf-8") as f:
            return audio_path, f.read()
    except Exception as e:
        return None, f"❌ Lỗi: {str(e)}"
157
+
158
def synthesize_speech(text, voice_choice, custom_audio, custom_text, mode_tab, generation_mode):
    """Synthesis with model check.

    Generator Gradio handler yielding (audio, status_text) pairs. In
    Standard mode the audio yields are file paths; in Streaming mode they
    are (sample_rate, ndarray) tuples followed by a final file path.
    """
    global tts, current_backbone, current_codec, model_loaded

    # Refuse to run until load_model() has completed successfully.
    if not model_loaded or tts is None:
        yield None, "⚠️ Vui lòng tải model trước!"
        return

    if not text or text.strip() == "":
        yield None, "⚠️ Vui lòng nhập văn bản!"
        return

    raw_text = text.strip()

    codec_config = CODEC_CONFIGS[current_codec]
    use_preencoded = codec_config['use_preencoded']

    # --- Set up the voice reference (custom upload vs. preset sample) ---
    if mode_tab == "custom_mode":
        if custom_audio is None or not custom_text:
            yield None, "⚠️ Thiếu Audio hoặc Text mẫu custom."
            return
        ref_audio_path = custom_audio
        ref_text_raw = custom_text
        ref_codes_path = None  # custom audio never has pre-encoded codes
    else:
        if voice_choice not in VOICE_SAMPLES:
            yield None, "⚠️ Vui lòng chọn giọng mẫu."
            return
        ref_audio_path = VOICE_SAMPLES[voice_choice]["audio"]
        ref_text_path = VOICE_SAMPLES[voice_choice]["text"]
        ref_codes_path = VOICE_SAMPLES[voice_choice]["codes"]

        if not os.path.exists(ref_audio_path):
            yield None, "❌ Không tìm thấy file audio mẫu."
            return

        with open(ref_text_path, "r", encoding="utf-8") as f:
            ref_text_raw = f.read()

    yield None, "📄 Đang xử lý Reference..."

    # --- Encode the reference: prefer cached .pt codes when the codec
    # requires them, otherwise encode the wav on the fly ---
    try:
        if use_preencoded and ref_codes_path and os.path.exists(ref_codes_path):
            ref_codes = torch.load(ref_codes_path, map_location="cpu")
        else:
            ref_codes = tts.encode_reference(ref_audio_path)

        if isinstance(ref_codes, torch.Tensor):
            ref_codes = ref_codes.cpu().numpy()
    except Exception as e:
        yield None, f"❌ Lỗi xử lý reference: {e}"
        return

    text_chunks = split_text_into_chunks(raw_text, max_chars=MAX_CHARS_PER_CHUNK)
    total_chunks = len(text_chunks)

    # === STANDARD MODE ===
    if generation_mode == "Standard (Một lần)":
        yield None, f"🚀 Bắt đầu tổng hợp chế độ Standard ({total_chunks} đoạn)..."

        all_audio_segments = []
        sr = 24000  # NOTE(review): output sample rate assumed 24 kHz — confirm against the codec
        silence_pad = np.zeros(int(sr * 0.15), dtype=np.float32)  # 150 ms gap between chunks

        start_time = time.time()

        try:
            for i, chunk in enumerate(text_chunks):
                yield None, f"⏳ Đang xử lý đoạn {i+1}/{total_chunks}..."

                chunk_wav = tts.infer(chunk, ref_codes, ref_text_raw)

                if chunk_wav is not None and len(chunk_wav) > 0:
                    all_audio_segments.append(chunk_wav)
                    # Insert silence between (not after) chunks.
                    if i < total_chunks - 1:
                        all_audio_segments.append(silence_pad)

            if not all_audio_segments:
                yield None, "❌ Không sinh được audio nào."
                return

            yield None, "💾 Đang ghép file và lưu..."

            final_wav = np.concatenate(all_audio_segments)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                sf.write(tmp.name, final_wav, sr)
                output_path = tmp.name

            process_time = time.time() - start_time
            yield output_path, f"✅ Hoàn tất! (Tổng thời gian: {process_time:.2f}s)"

        except Exception as e:
            import traceback
            traceback.print_exc()
            yield None, f"❌ Lỗi Standard Mode: {str(e)}"
            return

    # === STREAMING MODE ===
    else:
        sr = 24000
        crossfade_samples = int(sr * 0.03)  # 30 ms crossfade between streamed parts
        audio_queue = queue.Queue(maxsize=100)
        PRE_BUFFER_SIZE = 3  # parts to collect before playback starts

        end_event = threading.Event()
        error_event = threading.Event()
        error_msg = ""

        def producer_thread():
            # Background inference: pushes (sr, samples) onto audio_queue and
            # finishes with a None sentinel. The last crossfade tail is held
            # back between parts and flushed at the very end.
            nonlocal error_msg
            try:
                previous_tail = None
                chunk_count = 0

                for i, chunk_text in enumerate(text_chunks):
                    stream_gen = tts.infer_stream(chunk_text, ref_codes, ref_text_raw)

                    for part_idx, audio_part in enumerate(stream_gen):
                        if audio_part is None or len(audio_part) == 0:
                            continue

                        # Crossfade the held-back tail of the previous part
                        # into the head of this one to hide seams.
                        if previous_tail is not None and len(previous_tail) > 0:
                            overlap = min(len(previous_tail), len(audio_part), crossfade_samples)
                            if overlap > 0:
                                fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
                                fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)

                                blended = (audio_part[:overlap] * fade_in +
                                           previous_tail[-overlap:] * fade_out)

                                processed = np.concatenate([
                                    previous_tail[:-overlap] if len(previous_tail) > overlap else np.array([]),
                                    blended,
                                    audio_part[overlap:]
                                ])
                            else:
                                processed = np.concatenate([previous_tail, audio_part])

                            tail_size = min(crossfade_samples, len(processed))
                            previous_tail = processed[-tail_size:].copy()
                            output_chunk = processed[:-tail_size] if len(processed) > tail_size else processed
                        else:
                            tail_size = min(crossfade_samples, len(audio_part))
                            previous_tail = audio_part[-tail_size:].copy()
                            output_chunk = audio_part[:-tail_size] if len(audio_part) > tail_size else audio_part

                        if len(output_chunk) > 0:
                            audio_queue.put((sr, output_chunk))
                            chunk_count += 1

                # Flush the final held-back tail once all chunks are done.
                if previous_tail is not None and len(previous_tail) > 0:
                    audio_queue.put((sr, previous_tail))

            except Exception as e:
                import traceback
                traceback.print_exc()
                error_msg = str(e)
                error_event.set()
            finally:
                end_event.set()
                audio_queue.put(None)  # sentinel: no more audio

        threading.Thread(target=producer_thread, daemon=True).start()

        # A tiny silent chunk kicks the client audio element into gear.
        yield (sr, np.zeros(int(sr * 0.05))), "🔄 Đang buffering..."

        # Pre-buffer a few parts to avoid stuttering right at the start.
        pre_buffer = []
        while len(pre_buffer) < PRE_BUFFER_SIZE:
            try:
                item = audio_queue.get(timeout=5.0)
                if item is None:
                    break
                pre_buffer.append(item)
            except queue.Empty:
                if error_event.is_set():
                    yield None, f"❌ Lỗi: {error_msg}"
                    return
                break

        full_audio_buffer = []
        for sr, audio_data in pre_buffer:
            full_audio_buffer.append(audio_data)
            yield (sr, audio_data), "🔊 Đang phát..."

        # Drain the queue until the producer signals the end (None sentinel
        # or end_event with an empty queue).
        while True:
            try:
                item = audio_queue.get(timeout=0.05)
                if item is None:
                    break
                sr, audio_data = item
                full_audio_buffer.append(audio_data)
                yield (sr, audio_data), "🔊 Đang phát..."
            except queue.Empty:
                if error_event.is_set():
                    yield None, f"❌ Lỗi: {error_msg}"
                    break
                if end_event.is_set() and audio_queue.empty():
                    break
                continue

        # Final yield: write the whole stream to a .wav for replay/download.
        if full_audio_buffer:
            final_wav = np.concatenate(full_audio_buffer)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                sf.write(tmp.name, final_wav, sr)
            yield tmp.name, "✅ Hoàn tất Streaming!"
366
+
367
# --- 4. UI SETUP ---
# Gradio theme: Ocean base with an indigo/cyan gradient on primary buttons.
theme = gr.themes.Ocean(
    primary_hue="indigo",
    secondary_hue="cyan",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont('Inter'), 'ui-sans-serif', 'system-ui'],
).set(
    button_primary_background_fill="linear-gradient(90deg, #6366f1 0%, #0ea5e9 100%)",
    button_primary_background_fill_hover="linear-gradient(90deg, #4f46e5 0%, #0284c7 100%)",
)

# Custom CSS for the header card, status box, and model-card links.
css = """
.container { max-width: 1400px; margin: auto; }
.header-box {
    text-align: center;
    margin-bottom: 25px;
    padding: 25px;
    background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
    border-radius: 12px;
    color: white;
}
.header-title {
    font-size: 2.5rem;
    font-weight: 800;
    /* Bỏ hiệu ứng tô màu gradient ở đây và chuyển nó sang thẻ con */
}
.gradient-text {
    background: -webkit-linear-gradient(45deg, #60A5FA, #22D3EE);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.header-icon {
    color: white; /* Ép màu trắng */
}
.status-box {
    font-weight: bold;
    text-align: center;
    border: none;
    background: transparent;
}
.model-card {
    background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
    border-radius: 12px;
    padding: 20px;
    margin-bottom: 25px;
    border: 1px solid #cbd5e1;
}
.model-card-title {
    font-size: 1.1rem;
    font-weight: 700;
    color: #1e293b;
    margin-bottom: 12px;
    display: flex;
    align-items: center;
    gap: 8px;
}
.model-card-content {
    display: flex;
    flex-wrap: wrap;
    justify-content: center;
    align-items: center;
    gap: 15px;
    font-size: 0.9rem;
    text-align: center;
}
.model-card-item {
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 6px;
    color: #475569;
}
.model-card-link {
    color: #3b82f6;
    text-decoration: none;
    font-weight: 500;
    transition: color 0.2s;
}
.model-card-link:hover {
    color: #2563eb;
    text-decoration: underline;
}
"""
450
+
451
# Example (text, preset voice) pairs.
# NOTE(review): defined but not wired to any gr.Examples component in this
# file — confirm whether this is dead code or consumed elsewhere.
EXAMPLES_LIST = [
    ["Về miền Tây không chỉ để ngắm nhìn sông nước hữu tình, mà còn để cảm nhận tấm chân tình của người dân nơi đây.", "Vĩnh (nam miền Nam)"],
    ["Hà Nội những ngày vào thu mang một vẻ đẹp trầm mặc và cổ kính đến lạ thường.", "Bình (nam miền Bắc)"],
]
455
+
456
# Build the Gradio UI: header, model configuration, input/output columns,
# and the event wiring between them.
with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
    with gr.Column(elem_classes="container"):
        gr.HTML("""
        <div class="header-box">
            <h1 class="header-title">
                <span class="header-icon">🦜</span>
                <span class="gradient-text">VieNeu-TTS Studio</span>
            </h1>
            <div class="model-card-content">
                <div class="model-card-item">
                    <strong>Models:</strong>
                    <a href="https://huggingface.co/pnnbao-ump/VieNeu-TTS" target="_blank" class="model-card-link">VieNeu-TTS</a>
                    <span>•</span>
                    <a href="https://huggingface.co/pnnbao-ump/VieNeu-TTS-q4-gguf" target="_blank" class="model-card-link">Q4-GGUF</a>
                    <span>•</span>
                    <a href="https://huggingface.co/pnnbao-ump/VieNeu-TTS-q8-gguf" target="_blank" class="model-card-link">Q8-GGUF</a>
                </div>
                <div class="model-card-item">
                    <strong>Repository:</strong>
                    <a href="https://github.com/pnnbao97/VieNeu-TTS" target="_blank" class="model-card-link">GitHub</a>
                </div>
                <div class="model-card-item">
                    <strong>Tác giả:</strong>
                    <span>Phạm Nguyễn Ngọc Bảo</span>
                </div>
            </div>
        </div>
        """)

        # --- CONFIGURATION ---
        with gr.Group():
            with gr.Row():
                backbone_select = gr.Dropdown(list(BACKBONE_CONFIGS.keys()), value="VieNeu-TTS (GPU)", label="🦜 Backbone")
                codec_select = gr.Dropdown(list(CODEC_CONFIGS.keys()), value="NeuCodec (Standard)", label="🎵 Codec")
                device_choice = gr.Radio(["Auto", "CPU", "CUDA"], value="Auto", label="🖥️ Device")

            btn_load = gr.Button("🔄 Tải Model", variant="primary")
            model_status = gr.Markdown("⏳ Chưa tải model.")

    with gr.Row(elem_classes="container"):
        # --- INPUT ---
        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label=f"Văn bản (Streaming hỗ trợ tới {MAX_TOTAL_CHARS_STREAMING} ký tự, chia chunk {MAX_CHARS_PER_CHUNK} ký tự)",
                lines=4,
                value="Hà Nội, trái tim của Việt Nam, là một thành phố ngàn năm văn hiến với bề dày lịch sử và văn hóa độc đáo. Bước chân trên những con phố cổ kính quanh Hồ Hoàn Kiếm, du khách như được du hành ngược thời gian, chiêm ngưỡng kiến trúc Pháp cổ điển hòa quyện với nét kiến trúc truyền thống Việt Nam. Mỗi con phố trong khu phố cổ mang một tên gọi đặc trưng, phản ánh nghề thủ công truyền thống từng thịnh hành nơi đây như phố Hàng Bạc, Hàng Đào, Hàng Mã. Ẩm thực Hà Nội cũng là một điểm nhấn đặc biệt, từ tô phở nóng hổi buổi sáng, bún chả thơm lừng trưa hè, đến chè Thái ngọt ngào chiều thu. Những món ăn dân dã này đã trở thành biểu tượng của văn hóa ẩm thực Việt, được cả thế giới yêu mến. Người Hà Nội nổi tiếng với tính cách hiền hòa, lịch thiệp nhưng cũng rất cầu toàn trong từng chi tiết nhỏ, từ cách pha trà sen cho đến cách chọn hoa sen tây để thưởng trà.",
            )

            with gr.Tabs() as tabs:
                with gr.TabItem("👤 Preset", id="preset_mode"):
                    # NOTE(review): "GGUF Q4" is not a BACKBONE_CONFIGS key and the
                    # filter in get_voice_options is case-sensitive, so this
                    # currently returns every voice — confirm intended default.
                    initial_voices = get_voice_options("GGUF Q4")
                    default_voice = initial_voices[0] if initial_voices else None
                    voice_select = gr.Dropdown(initial_voices, value=default_voice, label="Giọng mẫu")

                with gr.TabItem("🎙️ Custom", id="custom_mode"):
                    custom_audio = gr.Audio(label="File mẫu (.wav)", type="filepath")
                    custom_text = gr.Textbox(label="Lời thoại mẫu")

            generation_mode = gr.Radio(
                ["Standard (Một lần)"],
                value="Standard (Một lần)",
                label="Chế độ sinh"
            )

            # Hidden state tracking which tab (preset/custom) is active.
            current_mode = gr.Textbox(visible=False, value="preset_mode")

            # "Start" button stays disabled until load_model() succeeds.
            btn_generate = gr.Button("🎵 Bắt đầu", variant="primary", size="lg", interactive=False)

        # --- OUTPUT ---
        with gr.Column(scale=2):
            audio_output = gr.Audio(
                label="Kết quả",
                type="filepath",
                autoplay=True,
                show_download_button=True
            )
            status_output = gr.Textbox(label="Trạng thái", elem_classes="status-box")

    # --- EVENT HANDLERS ---
    def update_info(backbone):
        # Show whether the selected backbone supports streaming.
        return f"Streaming: {'✅' if BACKBONE_CONFIGS[backbone]['supports_streaming'] else '❌'}"

    backbone_select.change(update_info, backbone_select, model_status)
    backbone_select.change(update_voice_dropdown, [backbone_select, voice_select], voice_select)

    # NOTE(review): indexing tabs.children relies on Gradio internals and may
    # break across Gradio versions — confirm against the installed version.
    tabs.children[0].select(lambda: "preset_mode", outputs=current_mode)
    tabs.children[1].select(lambda: "custom_mode", outputs=current_mode)

    # Loading updates the status text and toggles both buttons (see load_model).
    btn_load.click(
        fn=load_model,
        inputs=[backbone_select, codec_select, device_choice],
        outputs=[model_status, btn_generate, btn_load]
    )

    btn_generate.click(
        fn=synthesize_speech,
        inputs=[text_input, voice_select, custom_audio, custom_text, current_mode, generation_mode],
        outputs=[audio_output, status_output]
    )

if __name__ == "__main__":
    demo.queue().launch()
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (128 Bytes). View file
 
utils/__pycache__/core_utils.cpython-312.pyc ADDED
Binary file (1.9 kB). View file
 
utils/__pycache__/normalize_text.cpython-312.pyc ADDED
Binary file (23.9 kB). View file
 
utils/__pycache__/phonemize_text.cpython-312.pyc ADDED
Binary file (5.74 kB). View file
 
utils/core_utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+
4
def split_text_into_chunks(text: str, max_chars: int = 256) -> List[str]:
    """Greedily pack sentences into chunks of at most ``max_chars`` characters.

    Sentence boundaries (., !, ?, …) are the preferred split points; a single
    sentence longer than ``max_chars`` is broken on whitespace instead.
    """
    pieces: List[str] = []
    pending = ""

    for sentence in re.split(r"(?<=[\.\!\?\…])\s+", text.strip()):
        sentence = sentence.strip()
        if not sentence:
            continue

        if len(sentence) > max_chars:
            # Flush whatever was accumulated, then split the oversized
            # sentence word by word.
            if pending:
                pieces.append(pending.strip())
                pending = ""
            run = ""
            for token in sentence.split():
                grown = f"{run} {token}".strip() if run else token
                if len(grown) > max_chars and run:
                    pieces.append(run.strip())
                    run = token
                else:
                    run = grown
            if run:
                pieces.append(run.strip())
            continue

        merged = f"{pending} {sentence}".strip() if pending else sentence
        if len(merged) <= max_chars:
            pending = merged
        else:
            if pending:
                pieces.append(pending.strip())
            pending = sentence

    if pending:
        pieces.append(pending.strip())
    return [piece for piece in pieces if piece]
utils/normalize_text.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ class VietnameseTTSNormalizer:
4
+ """
5
+ A text normalizer for Vietnamese Text-to-Speech systems.
6
+ Converts numbers, dates, units, and special characters into readable Vietnamese text.
7
+ """
8
+
9
    def __init__(self):
        # Unit abbreviation -> Vietnamese reading. _normalize_units matches
        # these case-insensitively, longest abbreviation first, so order here
        # only matters among equal-length keys.
        self.units = {
            # Length
            'km': 'ki lô mét', 'dm': 'đê xi mét', 'cm': 'xen ti mét',
            'mm': 'mi li mét', 'nm': 'na nô mét', 'µm': 'mic rô mét',
            'μm': 'mic rô mét', 'm': 'mét',

            # Mass
            'kg': 'ki lô gam', 'g': 'gam', 'mg': 'mi li gam',

            # Area
            'km²': 'ki lô mét vuông', 'km2': 'ki lô mét vuông',
            'm²': 'mét vuông', 'm2': 'mét vuông',
            'cm²': 'xen ti mét vuông', 'cm2': 'xen ti mét vuông',
            'mm²': 'mi li mét vuông', 'mm2': 'mi li mét vuông',
            'ha': 'héc ta',

            # Volume
            'km³': 'ki lô mét khối', 'km3': 'ki lô mét khối',
            'm³': 'mét khối', 'm3': 'mét khối',
            'cm³': 'xen ti mét khối', 'cm3': 'xen ti mét khối',
            'mm³': 'mi li mét khối', 'mm3': 'mi li mét khối',
            'l': 'lít', 'dl': 'đê xi lít', 'ml': 'mi li lít', 'hl': 'héc tô lít',

            # Electricity
            'v': 'vôn', 'kv': 'ki lô vôn', 'mv': 'mi li vôn',
            'a': 'am pe', 'ma': 'mi li am pe', 'ka': 'ki lô am pe',
            'w': 'oát', 'kw': 'ki lô oát', 'mw': 'mê ga oát', 'gw': 'gi ga oát',
            'kwh': 'ki lô oát giờ', 'mwh': 'mê ga oát giờ', 'wh': 'oát giờ',
            'ω': 'ôm', 'ohm': 'ôm', 'kω': 'ki lô ôm', 'mω': 'mê ga ôm',

            # Frequency
            'hz': 'héc', 'khz': 'ki lô héc', 'mhz': 'mê ga héc', 'ghz': 'gi ga héc',

            # Pressure
            'pa': 'pát cal', 'kpa': 'ki lô pát cal', 'mpa': 'mê ga pát cal',
            'bar': 'ba', 'mbar': 'mi li ba', 'atm': 'át mốt phia', 'psi': 'pi ét xai',

            # Energy
            'j': 'giun', 'kj': 'ki lô giun',
            'cal': 'ca lo', 'kcal': 'ki lô ca lo',
        }

        # Vietnamese digit names 0-9, used to read decimal digits one by one.
        self.digits = ['không', 'một', 'hai', 'ba', 'bốn',
                       'năm', 'sáu', 'bảy', 'tám', 'chín']
+ 'năm', 'sáu', 'bảy', 'tám', 'chín']
46
+
47
+ def normalize(self, text):
48
+ """Main normalization pipeline."""
49
+ text = text.lower()
50
+ text = self._normalize_temperature(text)
51
+ text = self._normalize_currency(text)
52
+ text = self._normalize_percentage(text)
53
+ text = self._normalize_units(text)
54
+ text = self._normalize_time(text)
55
+ text = self._normalize_date(text)
56
+ text = self._normalize_phone(text)
57
+ text = self._normalize_numbers(text)
58
+ text = self._number_to_words(text)
59
+ text = self._normalize_special_chars(text)
60
+ text = self._normalize_whitespace(text)
61
+ return text
62
+
63
+ def _normalize_temperature(self, text):
64
+ """Convert temperature notation to words."""
65
+ text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'âm \1 độ xê', text, flags=re.IGNORECASE)
66
+ text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'âm \1 độ ép', text, flags=re.IGNORECASE)
67
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'\1 độ xê', text, flags=re.IGNORECASE)
68
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'\1 độ ép', text, flags=re.IGNORECASE)
69
+ text = re.sub(r'°', ' độ ', text)
70
+ return text
71
+
72
+ def _normalize_currency(self, text):
73
+ """Convert currency notation to words."""
74
+ def decimal_currency(match):
75
+ whole = match.group(1)
76
+ decimal = match.group(2)
77
+ unit = match.group(3)
78
+ decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
79
+ unit_map = {'k': 'nghìn', 'm': 'triệu', 'b': 'tỷ'}
80
+ unit_word = unit_map.get(unit.lower(), unit)
81
+ return f"{whole} phẩy {decimal_words} {unit_word}"
82
+
83
+ text = re.sub(r'(\d+)[.,](\d+)\s*([kmb])\b', decimal_currency, text, flags=re.IGNORECASE)
84
+ text = re.sub(r'(\d+)\s*k\b', r'\1 nghìn', text, flags=re.IGNORECASE)
85
+ text = re.sub(r'(\d+)\s*m\b', r'\1 triệu', text, flags=re.IGNORECASE)
86
+ text = re.sub(r'(\d+)\s*b\b', r'\1 tỷ', text, flags=re.IGNORECASE)
87
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*đ\b', r'\1 đồng', text)
88
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*vnd\b', r'\1 đồng', text, flags=re.IGNORECASE)
89
+ text = re.sub(r'\$\s*(\d+(?:[.,]\d+)?)', r'\1 đô la', text)
90
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*\$', r'\1 đô la', text)
91
+ return text
92
+
93
+ def _normalize_percentage(self, text):
94
+ """Convert percentage to words."""
95
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*%', r'\1 phần trăm', text)
96
+ return text
97
+
98
+ def _normalize_units(self, text):
99
+ """Convert measurement units to words."""
100
+ def expand_compound_with_number(match):
101
+ number = match.group(1)
102
+ unit1 = match.group(2).lower()
103
+ unit2 = match.group(3).lower()
104
+ full_unit1 = self.units.get(unit1, unit1)
105
+ full_unit2 = self.units.get(unit2, unit2)
106
+ return f"{number} {full_unit1} trên {full_unit2}"
107
+
108
+ def expand_compound_without_number(match):
109
+ unit1 = match.group(1).lower()
110
+ unit2 = match.group(2).lower()
111
+ full_unit1 = self.units.get(unit1, unit1)
112
+ full_unit2 = self.units.get(unit2, unit2)
113
+ return f"{full_unit1} trên {full_unit2}"
114
+
115
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
116
+ expand_compound_with_number, text)
117
+ text = re.sub(r'\b([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
118
+ expand_compound_without_number, text)
119
+
120
+ sorted_units = sorted(self.units.items(), key=lambda x: len(x[0]), reverse=True)
121
+ for unit, full_name in sorted_units:
122
+ pattern = r'(\d+(?:[.,]\d+)?)\s*' + re.escape(unit) + r'\b'
123
+ text = re.sub(pattern, rf'\1 {full_name}', text, flags=re.IGNORECASE)
124
+
125
+ for unit, full_name in sorted_units:
126
+ if any(c in unit for c in '²³°'):
127
+ pattern = r'\b' + re.escape(unit) + r'\b'
128
+ text = re.sub(pattern, full_name, text, flags=re.IGNORECASE)
129
+
130
+ return text
131
+
132
+ def _normalize_time(self, text):
133
+ """Convert time notation to words with validation."""
134
+
135
+ def validate_and_convert_time(match):
136
+ """Validate time components before converting."""
137
+ groups = match.groups()
138
+
139
+ # HH:MM:SS format
140
+ if len(groups) == 3:
141
+ hour, minute, second = groups
142
+ hour_int, minute_int, second_int = int(hour), int(minute), int(second)
143
+
144
+ # Validate ranges
145
+ if not (0 <= hour_int <= 23):
146
+ return match.group(0) # Return original if invalid
147
+ if not (0 <= minute_int <= 59):
148
+ return match.group(0)
149
+ if not (0 <= second_int <= 59):
150
+ return match.group(0)
151
+
152
+ return f"{hour} giờ {minute} phút {second} giây"
153
+
154
+ # HH:MM or HHhMM format
155
+ elif len(groups) == 2:
156
+ hour, minute = groups
157
+ hour_int, minute_int = int(hour), int(minute)
158
+
159
+ # Validate ranges
160
+ if not (0 <= hour_int <= 23):
161
+ return match.group(0)
162
+ if not (0 <= minute_int <= 59):
163
+ return match.group(0)
164
+
165
+ return f"{hour} giờ {minute} phút"
166
+
167
+ # HHh format
168
+ else:
169
+ hour = groups[0]
170
+ hour_int = int(hour)
171
+
172
+ if not (0 <= hour_int <= 23):
173
+ return match.group(0)
174
+
175
+ return f"{hour} giờ"
176
+
177
+ # Apply patterns with validation
178
+ text = re.sub(r'(\d{1,2}):(\d{2}):(\d{2})', validate_and_convert_time, text)
179
+ text = re.sub(r'(\d{1,2}):(\d{2})', validate_and_convert_time, text)
180
+ text = re.sub(r'(\d{1,2})h(\d{2})', validate_and_convert_time, text)
181
+ text = re.sub(r'(\d{1,2})h\b', validate_and_convert_time, text)
182
+
183
+ return text
184
+
185
+ def _normalize_date(self, text):
186
+ """Convert date notation to words with validation."""
187
+
188
+ def is_valid_date(day, month, year):
189
+ """Check if date components are valid."""
190
+ day, month, year = int(day), int(month), int(year)
191
+
192
+ # Basic range checks
193
+ if not (1 <= day <= 31):
194
+ return False
195
+ if not (1 <= month <= 12):
196
+ return False
197
+
198
+ return True
199
+
200
+ def date_to_text(match):
201
+ day, month, year = match.groups()
202
+ if is_valid_date(day, month, year):
203
+ return f"ngày {day} tháng {month} năm {year}"
204
+ return match.group(0) # Return original if invalid
205
+
206
+ def date_iso_to_text(match):
207
+ year, month, day = match.groups()
208
+ if is_valid_date(day, month, year):
209
+ return f"ngày {day} tháng {month} năm {year}"
210
+ return match.group(0)
211
+
212
+ def date_short_year(match):
213
+ day, month, year = match.groups()
214
+ full_year = f"20{year}" if int(year) < 50 else f"19{year}"
215
+ if is_valid_date(day, month, full_year):
216
+ return f"ngày {day} tháng {month} năm {full_year}"
217
+ return match.group(0)
218
+
219
+ # Apply patterns with validation
220
+ text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b',
221
+ lambda m: date_to_text(m).replace('ngày ngày', 'ngày'), text)
222
+ text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b',
223
+ lambda m: date_short_year(m).replace('ngày ngày', 'ngày'), text)
224
+ text = re.sub(r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', date_iso_to_text, text)
225
+ text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b', date_to_text, text)
226
+ text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b', date_short_year, text)
227
+
228
+ return text
229
+
230
+ def _normalize_phone(self, text):
231
+ """Convert phone numbers to digit-by-digit reading."""
232
+ def phone_to_text(match):
233
+ phone = match.group(0)
234
+ phone = re.sub(r'[^\d]', '', phone)
235
+
236
+ if phone.startswith('84') and len(phone) >= 10:
237
+ phone = '0' + phone[2:]
238
+
239
+ if 10 <= len(phone) <= 11:
240
+ words = [self.digits[int(d)] for d in phone]
241
+ return ' '.join(words) + ' '
242
+
243
+ return match.group(0)
244
+
245
+ text = re.sub(r'(\+84|84)[\s\-\.]?\d[\d\s\-\.]{7,}', phone_to_text, text)
246
+ text = re.sub(r'\b0\d[\d\s\-\.]{8,}', phone_to_text, text)
247
+ return text
248
+
249
+ def _normalize_numbers(self, text):
250
+ text = re.sub(r'(\d+(?:[,.]\d+)?)%', lambda m: f'{m.group(1)} phần trăm', text)
251
+ # 1. Xóa dấu thousand separator trước
252
+ text = re.sub(r'(\d{1,3})(?:\.(\d{3}))+', lambda m: m.group(0).replace('.', ''), text)
253
+
254
+ # 2. Chuyển số thập phân thành chữ
255
+ def decimal_to_words(match):
256
+ whole = match.group(1)
257
+ decimal = match.group(2)
258
+ decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
259
+ separator = 'phẩy' if ',' in match.group(0) else 'chấm'
260
+ return f"{whole} {separator} {decimal_words}"
261
+
262
+ # 2a. Dấu phẩy
263
+ text = re.sub(r'(\d+),(\d+)', decimal_to_words, text)
264
+ # 2b. Dấu chấm (1-2 chữ số thập phân)
265
+ text = re.sub(r'(\d+)\.(\d{1,2})\b', decimal_to_words, text)
266
+
267
+ return text
268
+
269
+ def _read_two_digits(self, n):
270
+ """Read two-digit numbers in Vietnamese."""
271
+ if n < 10:
272
+ return self.digits[n]
273
+ elif n == 10:
274
+ return "mười"
275
+ elif n < 20:
276
+ if n == 15:
277
+ return "mười lăm"
278
+ return f"mười {self.digits[n % 10]}"
279
+ else:
280
+ tens = n // 10
281
+ ones = n % 10
282
+ if ones == 0:
283
+ return f"{self.digits[tens]} mươi"
284
+ elif ones == 1:
285
+ return f"{self.digits[tens]} mươi mốt"
286
+ elif ones == 5:
287
+ return f"{self.digits[tens]} mươi lăm"
288
+ else:
289
+ return f"{self.digits[tens]} mươi {self.digits[ones]}"
290
+
291
+ def _read_three_digits(self, n):
292
+ """Read three-digit numbers in Vietnamese."""
293
+ if n < 100:
294
+ return self._read_two_digits(n)
295
+
296
+ hundreds = n // 100
297
+ remainder = n % 100
298
+ result = f"{self.digits[hundreds]} trăm"
299
+
300
+ if remainder == 0:
301
+ return result
302
+ elif remainder < 10:
303
+ result += f" lẻ {self.digits[remainder]}"
304
+ else:
305
+ result += f" {self._read_two_digits(remainder)}"
306
+
307
+ return result
308
+
309
+ def _convert_number_to_words(self, num):
310
+ """Convert a number to Vietnamese words."""
311
+ if num == 0:
312
+ return "không"
313
+
314
+ if num < 0:
315
+ return f"âm {self._convert_number_to_words(-num)}"
316
+
317
+ if num >= 1000000000:
318
+ billion = num // 1000000000
319
+ remainder = num % 1000000000
320
+ result = f"{self._read_three_digits(billion)} tỷ"
321
+ if remainder > 0:
322
+ result += f" {self._convert_number_to_words(remainder)}"
323
+ return result
324
+
325
+ elif num >= 1000000:
326
+ million = num // 1000000
327
+ remainder = num % 1000000
328
+ result = f"{self._read_three_digits(million)} triệu"
329
+ if remainder > 0:
330
+ result += f" {self._convert_number_to_words(remainder)}"
331
+ return result
332
+
333
+ elif num >= 1000:
334
+ thousand = num // 1000
335
+ remainder = num % 1000
336
+ result = f"{self._read_three_digits(thousand)} nghìn"
337
+ if remainder > 0:
338
+ if remainder < 100:
339
+ result += f" không trăm {self._read_two_digits(remainder)}"
340
+ else:
341
+ result += f" {self._read_three_digits(remainder)}"
342
+ return result
343
+
344
+ else:
345
+ return self._read_three_digits(num)
346
+
347
+ def _number_to_words(self, text):
348
+ """Convert all remaining numbers to words."""
349
+ def convert_number(match):
350
+ num = int(match.group(0))
351
+ return self._convert_number_to_words(num)
352
+
353
+ text = re.sub(r'\b\d+\b', convert_number, text)
354
+ return text
355
+
356
+ def _normalize_special_chars(self, text):
357
+ """Handle special characters."""
358
+ text = text.replace('&', ' và ')
359
+ text = text.replace('+', ' cộng ')
360
+ text = text.replace('=', ' bằng ')
361
+ text = text.replace('#', ' thăng ')
362
+ text = re.sub(r'[\[\]\(\)\{\}]', ' ', text)
363
+ text = re.sub(r'\s+[-–—]+\s+', ' ', text)
364
+ text = re.sub(r'\.{2,}', ' ', text)
365
+ text = re.sub(r'\s+\.\s+', ' ', text)
366
+ text = re.sub(r'[^\w\sàáảãạăắằẳẵặâấầẩẫậèéẻẽẹêếềểễệìíỉĩịòóỏõọôốồổỗộơớờởỡợùúủũụưứừửữựỳýỷỹỵđ.,!?;:@%]', ' ', text)
367
+ return text
368
+
369
+ def _normalize_whitespace(self, text):
370
+ """Normalize whitespace."""
371
+ text = re.sub(r'\s+', ' ', text)
372
+ text = text.strip()
373
+ return text
374
+
375
+
376
if __name__ == "__main__":
    # Smoke-test the normalizer against a spread of tricky inputs:
    # prices, phones, units, dates, times, decimals and negatives.
    normalizer = VietnameseTTSNormalizer()

    samples = (
        "Giá 2.500.000đ (giảm 50%), mua trước 14h30 ngày 15/12/2025",
        "Liên hệ: 0912-345-678 hoặc email@example.com",
        "Tốc độ 120km/h, trọng lượng 75kg",
        "Nhiệt độ 36,5°C, độ ẩm 80%",
        "Số pi = 3,14159",
        "Giá trị tăng 2.5M, đạt 10B",
        "Nhiệt độ -15°C vào mùa đông",
        "Điện áp 220V, công suất 2.5kW, tần số 50Hz",
        "Tôi đi lấy l nước về nhà",
        "Cần 5l nước cho công thức này",
        "Vận tốc ánh sáng 299792km/s",
        "Mật độ dân số 450 người/km2",
        "Công suất 100 W/m2",
        "Hôm nay 2025-01-15",
        "Gọi +84 912 345 678",
        "Nhiệt độ 25°C lúc 14:30:45",
        "Ngày 15/12/25",
        "Giá 3.140.159",
    )

    banner = "=" * 80
    print(banner)
    print("VIETNAMESE TTS NORMALIZATION TEST")
    print(banner)

    for sample in samples:
        print(f"\n📝 Input: {sample}")
        print(f"🎵 Output: {normalizer.normalize(sample)}")
        print("-" * 80)
utils/phoneme_dict.json ADDED
The diff for this file is too large to render. See raw diff
 
utils/phonemize_text.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import platform
4
+ import glob
5
+ from phonemizer import phonemize
6
+ from phonemizer.backend.espeak.espeak import EspeakWrapper
7
+ from utils.normalize_text import VietnameseTTSNormalizer
8
+
9
+ # Configuration
10
+ PHONEME_DICT_PATH = os.getenv(
11
+ 'PHONEME_DICT_PATH',
12
+ os.path.join(os.path.dirname(__file__), "phoneme_dict.json")
13
+ )
14
+
15
def load_phoneme_dict(path=PHONEME_DICT_PATH):
    """Read the word -> phoneme mapping from *path* (a JSON file)."""
    try:
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except FileNotFoundError:
        # Re-raise with an actionable message instead of the bare OS error.
        raise FileNotFoundError(
            f"Phoneme dictionary not found at {path}. "
            "Please create it or set PHONEME_DICT_PATH environment variable."
        )
25
+
26
def setup_espeak_library():
    """Configure the eSpeak shared-library path for the current OS."""
    handlers = {
        "Windows": _setup_windows_espeak,
        "Linux": _setup_linux_espeak,
        "Darwin": _setup_macos_espeak,
    }
    system = platform.system()
    handler = handlers.get(system)
    if handler is None:
        raise OSError(
            f"Unsupported OS: {system}. "
            "Only Windows, Linux, and macOS are supported."
        )
    handler()
41
+
42
def _setup_windows_espeak():
    """Point phonemizer at the default eSpeak NG DLL install on Windows."""
    dll_path = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
    if not os.path.exists(dll_path):
        raise FileNotFoundError(
            f"eSpeak library not found at {dll_path}. "
            "Please install eSpeak NG from: https://github.com/espeak-ng/espeak-ng/releases"
        )
    EspeakWrapper.set_library(dll_path)
52
+
53
def _setup_linux_espeak():
    """Locate libespeak-ng on common Linux library paths."""
    candidate_globs = (
        "/usr/lib/x86_64-linux-gnu/libespeak-ng.so*",
        "/usr/lib/x86_64-linux-gnu/libespeak.so*",
        "/usr/lib/libespeak-ng.so*",
        "/usr/lib64/libespeak-ng.so*",
        "/usr/local/lib/libespeak-ng.so*",
    )

    for pattern in candidate_globs:
        hits = glob.glob(pattern)
        if hits:
            # Shortest path wins: prefers the unversioned .so symlink.
            EspeakWrapper.set_library(min(hits, key=len))
            return

    raise RuntimeError(
        "eSpeak NG library not found. Install with:\n"
        " Ubuntu/Debian: sudo apt-get install espeak-ng\n"
        " Fedora: sudo dnf install espeak-ng\n"
        " Arch: sudo pacman -S espeak-ng\n"
        "See: https://github.com/pnnbao97/VieNeu-TTS/issues/5"
    )
76
+
77
def _setup_macos_espeak():
    """Locate libespeak-ng on macOS (env override, Homebrew, MacPorts)."""
    candidates = (
        os.environ.get('PHONEMIZER_ESPEAK_LIBRARY'),  # explicit override wins
        "/opt/homebrew/lib/libespeak-ng.dylib",       # Apple Silicon Homebrew
        "/usr/local/lib/libespeak-ng.dylib",          # Intel Homebrew
        "/opt/local/lib/libespeak-ng.dylib",          # MacPorts
    )

    for candidate in candidates:
        if candidate and os.path.exists(candidate):
            EspeakWrapper.set_library(candidate)
            return

    raise FileNotFoundError(
        "eSpeak library not found. Install with:\n"
        " brew install espeak-ng\n"
        "Or set: export PHONEMIZER_ESPEAK_LIBRARY=/path/to/libespeak-ng.dylib"
    )
98
+
99
# Initialize module-level state at import time: eSpeak must be configured
# before any phonemize() call, and the dictionary/normalizer instances are
# shared by both public functions below.
try:
    setup_espeak_library()
    phoneme_dict = load_phoneme_dict()
    normalizer = VietnameseTTSNormalizer()
except Exception as e:
    # Surface the failure loudly but keep the original traceback intact.
    print(f"Initialization error: {e}")
    raise
107
+
108
def phonemize_text(text: str) -> str:
    """Normalize *text* and run it through the eSpeak phonemizer."""
    espeak_options = dict(
        language="vi",
        backend="espeak",
        preserve_punctuation=True,
        with_stress=True,
        language_switch="remove-flags",
    )
    return phonemize(normalizer.normalize(text), **espeak_options)
119
+
120
def phonemize_with_dict(text: str, phoneme_dict=phoneme_dict) -> str:
    """Phonemize *text*, preferring dictionary lookups over eSpeak calls.

    Unknown words are phonemized on the fly and memoized into
    *phoneme_dict*, so repeated words cost a single eSpeak call.
    """

    def lookup(word):
        if word in phoneme_dict:
            return phoneme_dict[word]
        try:
            phones = phonemize(
                word,
                language='vi',
                backend='espeak',
                preserve_punctuation=True,
                with_stress=True,
                language_switch='remove-flags'
            )

            # Force 'ɹ' as the first phoneme of r-initial words —
            # presumably to match the voice's training data; confirm
            # against the model card before changing.
            if word.lower().startswith('r'):
                phones = 'ɹ' + phones[1:]

            phoneme_dict[word] = phones
            return phones
        except Exception as e:
            # Best-effort fallback: keep the raw word rather than failing.
            print(f"Warning: Could not phonemize '{word}': {e}")
            return word

    normalized = normalizer.normalize(text)
    return ' '.join(lookup(word) for word in normalized.split())
vieneu_tts/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .vieneu_tts import VieNeuTTS
2
+
3
+ __all__ = ["VieNeuTTS"]
4
+
vieneu_tts/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (207 Bytes). View file
 
vieneu_tts/__pycache__/vieneu_tts.cpython-312.pyc ADDED
Binary file (17 kB). View file
 
vieneu_tts/vieneu_tts.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Generator
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+ from neucodec import NeuCodec, DistillNeuCodec
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ from utils.phonemize_text import phonemize_with_dict
9
+ import re
10
+
11
+ def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
12
+ # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
13
+ assert len(frames)
14
+ dtype = frames[0].dtype
15
+ shape = frames[0].shape[:-1]
16
+
17
+ total_size = 0
18
+ for i, frame in enumerate(frames):
19
+ frame_end = stride * i + frame.shape[-1]
20
+ total_size = max(total_size, frame_end)
21
+
22
+ sum_weight = np.zeros(total_size, dtype=dtype)
23
+ out = np.zeros(*shape, total_size, dtype=dtype)
24
+
25
+ offset: int = 0
26
+ for frame in frames:
27
+ frame_length = frame.shape[-1]
28
+ t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
29
+ weight = np.abs(0.5 - (t - 0.5))
30
+
31
+ out[..., offset : offset + frame_length] += weight * frame
32
+ sum_weight[offset : offset + frame_length] += weight
33
+ offset += stride
34
+ assert sum_weight.min() > 0
35
+ return out / sum_weight
36
+
37
class VieNeuTTS:
    """Vietnamese text-to-speech pipeline.

    A token-generating backbone LM (an HF transformers checkpoint, or a GGUF
    checkpoint run through llama.cpp) emits ``<|speech_N|>`` codec tokens,
    which a NeuCodec decoder turns back into a 24 kHz waveform. Voice cloning
    is driven by reference codes produced with :meth:`encode_reference`.
    Streaming synthesis is only available with the GGUF (llama.cpp) backbone.
    """

    def __init__(
        self,
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device="cpu",
        codec_repo="neuphonic/neucodec",
        codec_device="cpu",
    ):

        # Constants
        self.sample_rate = 24_000  # output audio rate (Hz)
        self.max_context = 2048  # LM context window, in tokens
        self.hop_length = 480  # audio samples produced per speech token
        self.streaming_overlap_frames = 1  # frames kept either side for cross-fading
        self.streaming_frames_per_chunk = 25  # new tokens decoded per streamed chunk
        self.streaming_lookforward = 5  # future frames given to the codec for context
        self.streaming_lookback = 50  # past frames given to the codec for context
        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length

        # ggml & onnx flags
        self._is_quantized_model = False
        self._is_onnx_codec = False

        # HF tokenizer (only populated for the torch / non-GGUF backbone)
        self.tokenizer = None

        # Load models
        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)

    def _load_backbone(self, backbone_repo, backbone_device):
        """Load the LM: llama.cpp for repos named *gguf*, HF transformers otherwise."""
        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")

        if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
            try:
                from llama_cpp import Llama
            except ImportError as e:
                raise ImportError(
                    "Failed to import `llama_cpp`. "
                    "Please install it with:\n"
                    " pip install llama-cpp-python"
                ) from e
            self.backbone = Llama.from_pretrained(
                repo_id=backbone_repo,
                filename="*.gguf",
                verbose=False,
                n_gpu_layers=-1 if backbone_device == "gpu" else 0,  # -1 = offload all layers
                n_ctx=self.max_context,
                mlock=True,
                flash_attn=True if backbone_device == "gpu" else False,
            )
            self._is_quantized_model = True

        else:
            # Torch path: standard HF causal LM plus its tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
                torch.device(backbone_device)
            )

    def _load_codec(self, codec_repo, codec_device):
        """Load the NeuCodec decoder variant selected by repo name."""
        print(f"Loading codec from: {codec_repo} on {codec_device} ...")
        match codec_repo:
            case "neuphonic/neucodec":
                self.codec = NeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/distill-neucodec":
                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/neucodec-onnx-decoder":
                if codec_device != "cpu":
                    raise ValueError("Onnx decoder only currently runs on CPU.")
                try:
                    from neucodec import NeuCodecOnnxDecoder
                except ImportError as e:
                    raise ImportError(
                        "Failed to import the onnx decoder."
                        "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
                    ) from e
                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
                self._is_onnx_codec = True
            case _:
                raise ValueError(f"Unsupported codec repository: {codec_repo}")

    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
        """
        Perform inference to generate speech from text using the TTS model and reference audio.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.tensor): Encoded reference.
            ref_text (str): Reference text for reference audio. Defaults to None.
        Returns:
            np.ndarray: Generated speech waveform.
        """

        # Generate tokens (backend chosen at construction time)
        if self._is_quantized_model:
            output_str = self._infer_ggml(ref_codes, ref_text, text)
        else:
            prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
            output_str = self._infer_torch(prompt_ids)

        # Decode the generated <|speech_N|> tokens into audio
        wav = self._decode(output_str)

        return wav

    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
        """
        Perform streaming inference to generate speech from text using the TTS model and reference audio.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.tensor): Encoded reference.
            ref_text (str): Reference text for reference audio. Defaults to None.
        Yields:
            np.ndarray: Generated speech waveform.
        """

        if self._is_quantized_model:
            return self._infer_stream_ggml(ref_codes, ref_text, text)
        else:
            raise NotImplementedError("Streaming is not implemented for the torch backend!")

    def encode_reference(self, ref_audio_path: str | Path):
        """Encode a reference audio file into codec tokens for voice cloning."""
        # The codec encoder consumes 16 kHz mono audio.
        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
        with torch.no_grad():
            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
        return ref_codes

    def _decode(self, codes: str):
        """Decode speech tokens to audio waveform."""
        # Extract speech token IDs using regex
        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]

        if len(speech_ids) == 0:
            raise ValueError(
                "No valid speech tokens found in the output. "
                "The model may not have generated proper speech tokens."
            )

        # Onnx decode
        if self._is_onnx_codec:
            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
            recon = self.codec.decode_code(codes)
        # Torch decode
        else:
            with torch.no_grad():
                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
                    self.codec.device
                )
                recon = self.codec.decode_code(codes).cpu().numpy()

        # Drop the batch and channel axes -> 1-D waveform.
        return recon[0, 0, :]

    def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
        """Build the torch-backbone prompt: chat scaffold + phonemized text + reference codes."""
        # Reference transcript is prepended so the model continues in the same voice.
        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)

        speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
        speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
        text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
        text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
        text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")

        input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
        chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
        ids = self.tokenizer.encode(chat)

        # Splice the phonemized text (wrapped in prompt markers) in place of
        # the <|TEXT_REPLACE|> placeholder.
        text_replace_idx = ids.index(text_replace)
        ids = (
            ids[:text_replace_idx]
            + [text_prompt_start]
            + input_ids
            + [text_prompt_end]
            + ids[text_replace_idx + 1 :]  # noqa
        )

        # Replace <|SPEECH_REPLACE|> with the generation-start marker followed
        # by the reference speech codes; the model continues from there.
        speech_replace_idx = ids.index(speech_replace)
        codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
        codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
        ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)

        return ids

    def _infer_torch(self, prompt_ids: list[int]) -> str:
        """Sample speech tokens from the HF backbone and return the decoded string."""
        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
        speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
        with torch.no_grad():
            output_tokens = self.backbone.generate(
                prompt_tensor,
                max_length=self.max_context,
                eos_token_id=speech_end_id,
                do_sample=True,
                temperature=1.0,
                top_k=50,
                use_cache=True,
                min_new_tokens=50,
            )
        # Keep only the newly generated tokens (strip the prompt prefix).
        input_length = prompt_tensor.shape[-1]
        output_str = self.tokenizer.decode(
            output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
        )
        return output_str

    def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
        """Sample speech tokens from the llama.cpp backbone (non-streaming)."""
        ref_text = phonemize_with_dict(ref_text)
        input_text = phonemize_with_dict(input_text)

        # llama.cpp takes the chat scaffold as a literal string prompt.
        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )
        output = self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=1.0,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
        )
        output_str = output["choices"][0]["text"]
        return output_str

    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
        """Stream audio chunks from the llama.cpp backbone.

        Tokens are accumulated and decoded in fixed-size chunks with lookback/
        lookforward context; overlapping decodes are cross-faded with
        :func:`_linear_overlap_add` to avoid clicks at chunk boundaries.
        """
        ref_text = phonemize_with_dict(ref_text)
        input_text = phonemize_with_dict(input_text)

        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )

        audio_cache: list[np.ndarray] = []  # decoded (overlapping) audio chunks
        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]  # all tokens so far, incl. reference
        n_decoded_samples: int = 0  # samples already yielded
        n_decoded_tokens: int = len(ref_codes)  # tokens already covered by yielded audio

        for item in self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=1.0,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
            stream=True
        ):
            output_str = item["choices"][0]["text"]
            token_cache.append(output_str)

            # Decode once enough new tokens (chunk + lookforward) have arrived.
            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:

                # decode chunk: include lookback/lookforward context so the
                # codec sees surrounding frames, then trim to the new region
                # (plus overlap frames for the cross-fade).
                tokens_start = max(
                    n_decoded_tokens
                    - self.streaming_lookback
                    - self.streaming_overlap_frames,
                    0
                )
                tokens_end = (
                    n_decoded_tokens
                    + self.streaming_frames_per_chunk
                    + self.streaming_lookforward
                    + self.streaming_overlap_frames
                )
                sample_start = (
                    n_decoded_tokens - tokens_start
                ) * self.hop_length
                sample_end = (
                    sample_start
                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
                )
                curr_codes = token_cache[tokens_start:tokens_end]
                recon = self._decode("".join(curr_codes))
                recon = recon[sample_start:sample_end]
                audio_cache.append(recon)

                # postprocess: cross-fade all chunks, then emit only the
                # samples not yielded yet.
                processed_recon = _linear_overlap_add(
                    audio_cache, stride=self.streaming_stride_samples
                )
                new_samples_end = len(audio_cache) * self.streaming_stride_samples
                processed_recon = processed_recon[
                    n_decoded_samples:new_samples_end
                ]
                n_decoded_samples = new_samples_end
                n_decoded_tokens += self.streaming_frames_per_chunk
                yield processed_recon

        # final decoding handled separately as non-constant chunk size
        remaining_tokens = len(token_cache) - n_decoded_tokens
        if len(token_cache) > n_decoded_tokens:
            tokens_start = max(
                len(token_cache)
                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
                0
            )
            sample_start = (
                len(token_cache)
                - tokens_start
                - remaining_tokens
                - self.streaming_overlap_frames
            ) * self.hop_length
            curr_codes = token_cache[tokens_start:]
            recon = self._decode("".join(curr_codes))
            recon = recon[sample_start:]
            audio_cache.append(recon)

        # Flush everything not yet emitted (also covers the no-remainder case).
        processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
        processed_recon = processed_recon[n_decoded_samples:]
        yield processed_recon