Hoanglinhn01 commited on
Commit
2cdfe2c
·
verified ·
1 Parent(s): 5d174e9

Upload 2 files

Browse files
Files changed (2) hide show
  1. vieneu_tts/__init__.py +4 -0
  2. vieneu_tts/vieneu_tts.py +895 -0
vieneu_tts/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .vieneu_tts import VieNeuTTS, FastVieNeuTTS
2
+
3
+ __all__ = ["VieNeuTTS", "FastVieNeuTTS"]
4
+
vieneu_tts/vieneu_tts.py ADDED
@@ -0,0 +1,895 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Generator
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+ from neucodec import NeuCodec, DistillNeuCodec
7
+ from utils.phonemize_text import phonemize_with_dict
8
+ from collections import defaultdict
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ import re
11
+ import gc
12
+
13
+ # ============================================================================
14
+ # Shared Utilities
15
+ # ============================================================================
16
+
17
+ def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
18
+ """Linear overlap-add for smooth audio concatenation"""
19
+ assert len(frames)
20
+ dtype = frames[0].dtype
21
+ shape = frames[0].shape[:-1]
22
+
23
+ total_size = 0
24
+ for i, frame in enumerate(frames):
25
+ frame_end = stride * i + frame.shape[-1]
26
+ total_size = max(total_size, frame_end)
27
+
28
+ sum_weight = np.zeros(total_size, dtype=dtype)
29
+ out = np.zeros(*shape, total_size, dtype=dtype)
30
+
31
+ offset: int = 0
32
+ for frame in frames:
33
+ frame_length = frame.shape[-1]
34
+ t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
35
+ weight = np.abs(0.5 - (t - 0.5))
36
+
37
+ out[..., offset : offset + frame_length] += weight * frame
38
+ sum_weight[offset : offset + frame_length] += weight
39
+ offset += stride
40
+ assert sum_weight.min() > 0
41
+ return out / sum_weight
42
+
43
+
44
+ def _compile_codec_with_triton(codec):
45
+ """Compile codec with Triton for faster decoding (Windows/Linux compatible)"""
46
+ try:
47
+ import triton
48
+
49
+ if hasattr(codec, 'dec') and hasattr(codec.dec, 'resblocks'):
50
+ if len(codec.dec.resblocks) > 2:
51
+ codec.dec.resblocks[2].forward = torch.compile(
52
+ codec.dec.resblocks[2].forward,
53
+ mode="reduce-overhead",
54
+ dynamic=True
55
+ )
56
+ print(" ✅ Triton compilation enabled for codec")
57
+ return True
58
+
59
+ except ImportError:
60
+ print(" ⚠️ Triton not found. Install for faster speed:")
61
+ print(" • Linux: pip install triton")
62
+ print(" • Windows: pip install triton-windows")
63
+ print(" (Optional but recommended)")
64
+ return False
65
+
66
+
67
+ # ============================================================================
68
+ # VieNeuTTS - Standard implementation (CPU/GPU compatible)
69
+ # Supports: PyTorch Transformers, GGUF/GGML quantized models
70
+ # ============================================================================
71
+
72
class VieNeuTTS:
    """
    Standard VieNeu-TTS text-to-speech engine (CPU/GPU compatible).

    Backends:
      - PyTorch + Transformers (CPU or GPU)
      - GGUF quantized models via llama-cpp-python (CPU optimized)

    Prefer this class for CPU-only environments, standard PyTorch
    workflows, or GGUF quantized models.
    """

    def __init__(
        self,
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device="cpu",
        codec_repo="neuphonic/neucodec",
        codec_device="cpu",
    ):
        """
        Load the backbone, the codec and (optionally) the watermarker.

        Args:
            backbone_repo: Model repository or path to a GGUF file.
            backbone_device: Device for the backbone ('cpu', 'cuda', 'gpu').
            codec_repo: Codec repository.
            codec_device: Device for the codec.
        """
        # Audio/codec constants.
        self.sample_rate = 24_000   # output sample rate (Hz)
        self.max_context = 2048     # LM context window (tokens)
        self.hop_length = 480       # samples per speech token

        # Streaming chunking parameters (in codec frames).
        self.streaming_overlap_frames = 1
        self.streaming_frames_per_chunk = 25
        self.streaming_lookforward = 5
        self.streaming_lookback = 50
        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length

        # Backend flags, filled in by the loaders below.
        self._is_quantized_model = False
        self._is_onnx_codec = False

        # HF tokenizer; populated only for the transformers backend.
        self.tokenizer = None

        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)

        # Watermarking is best-effort: a missing or incompatible `perth`
        # install simply disables it.
        try:
            import perth
            self.watermarker = perth.PerthImplicitWatermarker()
            print(" 🔒 Audio watermarking initialized (Perth)")
        except (ImportError, AttributeError):
            self.watermarker = None
132
+ def _load_backbone(self, backbone_repo, backbone_device):
133
+ # MPS device validation
134
+ if backbone_device == "mps":
135
+ if not torch.backends.mps.is_available():
136
+ print("Warning: MPS not available, falling back to CPU")
137
+ backbone_device = "cpu"
138
+
139
+ print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
140
+
141
+ if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
142
+ try:
143
+ from llama_cpp import Llama
144
+ except ImportError as e:
145
+ raise ImportError(
146
+ "Failed to import `llama_cpp`. "
147
+ "Xem hướng dẫn cài đặt llama_cpp_python phiên bản tối thiểu 0.3.16 tại: https://llama-cpp-python.readthedocs.io/en/latest/"
148
+ ) from e
149
+ self.backbone = Llama.from_pretrained(
150
+ repo_id=backbone_repo,
151
+ filename="*.gguf",
152
+ verbose=False,
153
+ n_gpu_layers=-1 if backbone_device == "gpu" else 0,
154
+ n_ctx=self.max_context,
155
+ mlock=True,
156
+ flash_attn=True if backbone_device == "gpu" else False,
157
+ )
158
+ self._is_quantized_model = True
159
+
160
+ else:
161
+ from transformers import AutoTokenizer, AutoModelForCausalLM
162
+ self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
163
+ self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
164
+ torch.device(backbone_device)
165
+ )
166
+
167
+ def _load_codec(self, codec_repo, codec_device):
168
+ # MPS device validation
169
+ if codec_device == "mps":
170
+ if not torch.backends.mps.is_available():
171
+ print("Warning: MPS not available for codec, falling back to CPU")
172
+ codec_device = "cpu"
173
+
174
+ print(f"Loading codec from: {codec_repo} on {codec_device} ...")
175
+ match codec_repo:
176
+ case "neuphonic/neucodec":
177
+ self.codec = NeuCodec.from_pretrained(codec_repo)
178
+ self.codec.eval().to(codec_device)
179
+ case "neuphonic/distill-neucodec":
180
+ self.codec = DistillNeuCodec.from_pretrained(codec_repo)
181
+ self.codec.eval().to(codec_device)
182
+ case "neuphonic/neucodec-onnx-decoder-int8":
183
+ if codec_device != "cpu":
184
+ raise ValueError("Onnx decoder only currently runs on CPU.")
185
+ try:
186
+ from neucodec import NeuCodecOnnxDecoder
187
+ except ImportError as e:
188
+ raise ImportError(
189
+ "Failed to import the onnx decoder."
190
+ "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
191
+ ) from e
192
+ self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
193
+ self._is_onnx_codec = True
194
+ case _:
195
+ raise ValueError(f"Unsupported codec repository: {codec_repo}")
196
+
197
+ def encode_reference(self, ref_audio_path: str | Path):
198
+ """Encode reference audio to codes"""
199
+ wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
200
+ wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0) # [1, 1, T]
201
+ with torch.no_grad():
202
+ ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
203
+ return ref_codes
204
+
205
+ def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
206
+ """
207
+ Perform inference to generate speech from text using the TTS model and reference audio.
208
+
209
+ Args:
210
+ text (str): Input text to be converted to speech.
211
+ ref_codes (np.ndarray | torch.tensor): Encoded reference.
212
+ ref_text (str): Reference text for reference audio.
213
+ Returns:
214
+ np.ndarray: Generated speech waveform.
215
+ """
216
+
217
+ # Generate tokens
218
+ if self._is_quantized_model:
219
+ output_str = self._infer_ggml(ref_codes, ref_text, text)
220
+ else:
221
+ prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
222
+ output_str = self._infer_torch(prompt_ids)
223
+
224
+ # Decode
225
+ wav = self._decode(output_str)
226
+
227
+ # Apply watermark if available
228
+ if self.watermarker:
229
+ wav = self.watermarker.apply_watermark(wav, sample_rate=self.sample_rate)
230
+
231
+ return wav
232
+
233
+ def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
234
+ """
235
+ Perform streaming inference to generate speech from text using the TTS model and reference audio.
236
+
237
+ Args:
238
+ text (str): Input text to be converted to speech.
239
+ ref_codes (np.ndarray | torch.tensor): Encoded reference.
240
+ ref_text (str): Reference text for reference audio.
241
+ Yields:
242
+ np.ndarray: Generated speech waveform.
243
+ """
244
+
245
+ if self._is_quantized_model:
246
+ return self._infer_stream_ggml(ref_codes, ref_text, text)
247
+ else:
248
+ raise NotImplementedError("Streaming is not implemented for the torch backend!")
249
+
250
+ def _decode(self, codes: str):
251
+ """Decode speech tokens to audio waveform."""
252
+ # Extract speech token IDs using regex
253
+ speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
254
+
255
+ if len(speech_ids) == 0:
256
+ raise ValueError(
257
+ "No valid speech tokens found in the output. "
258
+ "Lỗi này có thể do GPU của bạn không hỗ trợ định dạng bfloat16 (ví dụ: dòng T4, RTX 20-series) "
259
+ "dẫn đến sai số khi tính toán. Bạn hãy thử chuyển sang dùng phiên bản GGUF Q4/Q8 hoặc "
260
+ "bỏ chọn 'LMDeploy' trong Tùy chọn nâng cao."
261
+ )
262
+
263
+ # Onnx decode
264
+ if self._is_onnx_codec:
265
+ codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
266
+ recon = self.codec.decode_code(codes)
267
+ # Torch decode
268
+ else:
269
+ with torch.no_grad():
270
+ codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
271
+ self.codec.device
272
+ )
273
+ recon = self.codec.decode_code(codes).cpu().numpy()
274
+
275
+ return recon[0, 0, :]
276
+
277
+ def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
278
+ input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
279
+
280
+ speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
281
+ speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
282
+ text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
283
+ text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
284
+ text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
285
+
286
+ input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
287
+ chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
288
+ ids = self.tokenizer.encode(chat)
289
+
290
+ text_replace_idx = ids.index(text_replace)
291
+ ids = (
292
+ ids[:text_replace_idx]
293
+ + [text_prompt_start]
294
+ + input_ids
295
+ + [text_prompt_end]
296
+ + ids[text_replace_idx + 1 :] # noqa
297
+ )
298
+
299
+ speech_replace_idx = ids.index(speech_replace)
300
+ codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
301
+ codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
302
+ ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
303
+
304
+ return ids
305
+
306
+ def _infer_torch(self, prompt_ids: list[int]) -> str:
307
+ prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
308
+ speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
309
+ with torch.no_grad():
310
+ output_tokens = self.backbone.generate(
311
+ prompt_tensor,
312
+ max_length=self.max_context,
313
+ eos_token_id=speech_end_id,
314
+ do_sample=True,
315
+ temperature=1.0,
316
+ top_k=50,
317
+ use_cache=True,
318
+ min_new_tokens=50,
319
+ )
320
+ input_length = prompt_tensor.shape[-1]
321
+ output_str = self.tokenizer.decode(
322
+ output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
323
+ )
324
+ return output_str
325
+
326
+ def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
327
+ ref_text = phonemize_with_dict(ref_text)
328
+ input_text = phonemize_with_dict(input_text)
329
+
330
+ codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
331
+ prompt = (
332
+ f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
333
+ f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
334
+ )
335
+ output = self.backbone(
336
+ prompt,
337
+ max_tokens=self.max_context,
338
+ temperature=1.0,
339
+ top_k=50,
340
+ stop=["<|SPEECH_GENERATION_END|>"],
341
+ )
342
+ output_str = output["choices"][0]["text"]
343
+ return output_str
344
+
345
+ def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
346
+ ref_text = phonemize_with_dict(ref_text)
347
+ input_text = phonemize_with_dict(input_text)
348
+
349
+ codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
350
+ prompt = (
351
+ f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
352
+ f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
353
+ )
354
+
355
+ audio_cache: list[np.ndarray] = []
356
+ token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
357
+ n_decoded_samples: int = 0
358
+ n_decoded_tokens: int = len(ref_codes)
359
+
360
+ for item in self.backbone(
361
+ prompt,
362
+ max_tokens=self.max_context,
363
+ temperature=1.0,
364
+ top_k=50,
365
+ stop=["<|SPEECH_GENERATION_END|>"],
366
+ stream=True
367
+ ):
368
+ output_str = item["choices"][0]["text"]
369
+ token_cache.append(output_str)
370
+
371
+ if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
372
+
373
+ # decode chunk
374
+ tokens_start = max(
375
+ n_decoded_tokens
376
+ - self.streaming_lookback
377
+ - self.streaming_overlap_frames,
378
+ 0
379
+ )
380
+ tokens_end = (
381
+ n_decoded_tokens
382
+ + self.streaming_frames_per_chunk
383
+ + self.streaming_lookforward
384
+ + self.streaming_overlap_frames
385
+ )
386
+ sample_start = (
387
+ n_decoded_tokens - tokens_start
388
+ ) * self.hop_length
389
+ sample_end = (
390
+ sample_start
391
+ + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
392
+ )
393
+ curr_codes = token_cache[tokens_start:tokens_end]
394
+ recon = self._decode("".join(curr_codes))
395
+ recon = recon[sample_start:sample_end]
396
+ audio_cache.append(recon)
397
+
398
+ # postprocess
399
+ processed_recon = _linear_overlap_add(
400
+ audio_cache, stride=self.streaming_stride_samples
401
+ )
402
+ new_samples_end = len(audio_cache) * self.streaming_stride_samples
403
+ processed_recon = processed_recon[
404
+ n_decoded_samples:new_samples_end
405
+ ]
406
+ n_decoded_samples = new_samples_end
407
+ n_decoded_tokens += self.streaming_frames_per_chunk
408
+ yield processed_recon
409
+
410
+ # final decoding handled separately as non-constant chunk size
411
+ remaining_tokens = len(token_cache) - n_decoded_tokens
412
+ if len(token_cache) > n_decoded_tokens:
413
+ tokens_start = max(
414
+ len(token_cache)
415
+ - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
416
+ 0
417
+ )
418
+ sample_start = (
419
+ len(token_cache)
420
+ - tokens_start
421
+ - remaining_tokens
422
+ - self.streaming_overlap_frames
423
+ ) * self.hop_length
424
+ curr_codes = token_cache[tokens_start:]
425
+ recon = self._decode("".join(curr_codes))
426
+ recon = recon[sample_start:]
427
+ audio_cache.append(recon)
428
+
429
+ processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
430
+ processed_recon = processed_recon[n_decoded_samples:]
431
+ yield processed_recon
432
+
433
+
434
+ # ============================================================================
435
+ # FastVieNeuTTS - GPU-optimized implementation
436
+ # Requires: LMDeploy with CUDA
437
+ # ============================================================================
438
+
439
class FastVieNeuTTS:
    """
    GPU-optimized VieNeu-TTS built on LMDeploy's TurbomindEngine.
    """

    def __init__(
        self,
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device="cuda",
        codec_repo="neuphonic/neucodec",
        codec_device="cuda",
        memory_util=0.3,
        tp=1,
        enable_prefix_caching=True,
        quant_policy=0,
        enable_triton=True,
        max_batch_size=8,
    ):
        """
        Initialize FastVieNeuTTS with the LMDeploy backend and optimizations.

        Args:
            backbone_repo: Model repository.
            backbone_device: Device for the backbone (must be CUDA).
            codec_repo: Codec repository.
            codec_device: Device for the codec.
            memory_util: GPU memory utilization (0.0-1.0).
            tp: Tensor parallel size for multi-GPU.
            enable_prefix_caching: Enable prefix caching for faster batches.
            quant_policy: KV cache quantization (0=off, 8=int8, 4=int4).
            enable_triton: Enable Triton compilation for the codec.
            max_batch_size: Cap on batch size to prevent GPU overload.

        Raises:
            ValueError: If `backbone_device` is not a CUDA device.
        """
        if backbone_device != "cuda" and not backbone_device.startswith("cuda:"):
            raise ValueError("LMDeploy backend requires CUDA device")

        # Audio/codec constants.
        self.sample_rate = 24_000   # output sample rate (Hz)
        self.max_context = 2048     # LM context window (tokens)
        self.hop_length = 480       # samples per speech token

        # Streaming chunking parameters (in codec frames).
        self.streaming_overlap_frames = 1
        self.streaming_frames_per_chunk = 50
        self.streaming_lookforward = 5
        self.streaming_lookback = 50
        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length

        self.max_batch_size = max_batch_size

        # voice_name+path -> {'codes', 'ref_text'} cache (get_cached_reference).
        self._ref_cache = {}
        # user_id -> {'codes', 'ref_text'} registry (add_speaker).
        self.stored_dict = defaultdict(dict)

        # Flags set by the loaders below.
        self._is_onnx_codec = False
        self._triton_enabled = False

        self._load_backbone_lmdeploy(backbone_repo, memory_util, tp, enable_prefix_caching, quant_policy)
        self._load_codec(codec_repo, codec_device, enable_triton)

        # Watermarking is best-effort: a missing or incompatible `perth`
        # install simply disables it.
        try:
            import perth
            self.watermarker = perth.PerthImplicitWatermarker()
            print(" 🔒 Audio watermarking initialized (Perth)")
        except (ImportError, AttributeError):
            self.watermarker = None

        self._warmup_model()

        print("✅ FastVieNeuTTS with optimizations loaded successfully!")
        print(f" Max batch size: {self.max_batch_size} (adjustable to prevent GPU overload)")
512
+
513
+ def _load_backbone_lmdeploy(self, repo, memory_util, tp, enable_prefix_caching, quant_policy):
514
+ """Load backbone using LMDeploy's TurbomindEngine"""
515
+ print(f"Loading backbone with LMDeploy from: {repo}")
516
+
517
+ try:
518
+ from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
519
+ except ImportError as e:
520
+ raise ImportError(
521
+ "Failed to import `lmdeploy`. "
522
+ "Xem hướng dẫn cài đặt lmdeploy để tối ưu hiệu suất GPU tại: https://github.com/pnnbao97/VieNeu-TTS"
523
+ ) from e
524
+
525
+ backend_config = TurbomindEngineConfig(
526
+ cache_max_entry_count=memory_util,
527
+ tp=tp,
528
+ enable_prefix_caching=enable_prefix_caching,
529
+ dtype='bfloat16',
530
+ quant_policy=quant_policy
531
+ )
532
+
533
+ self.backbone = pipeline(repo, backend_config=backend_config)
534
+
535
+ self.gen_config = GenerationConfig(
536
+ top_p=0.95,
537
+ top_k=50,
538
+ temperature=1.0,
539
+ max_new_tokens=2048,
540
+ do_sample=True,
541
+ min_new_tokens=40,
542
+ )
543
+
544
+ print(f" LMDeploy TurbomindEngine initialized")
545
+ print(f" - Memory util: {memory_util}")
546
+ print(f" - Tensor Parallel: {tp}")
547
+ print(f" - Prefix caching: {enable_prefix_caching}")
548
+ print(f" - KV quant: {quant_policy} ({'Enabled' if quant_policy > 0 else 'Disabled'})")
549
+
550
+ def _load_codec(self, codec_repo, codec_device, enable_triton):
551
+ """Load codec with optional Triton compilation"""
552
+ print(f"Loading codec from: {codec_repo} on {codec_device}")
553
+
554
+ match codec_repo:
555
+ case "neuphonic/neucodec":
556
+ self.codec = NeuCodec.from_pretrained(codec_repo)
557
+ self.codec.eval().to(codec_device)
558
+ case "neuphonic/distill-neucodec":
559
+ self.codec = DistillNeuCodec.from_pretrained(codec_repo)
560
+ self.codec.eval().to(codec_device)
561
+ case "neuphonic/neucodec-onnx-decoder-int8":
562
+ if codec_device != "cpu":
563
+ raise ValueError("ONNX decoder only runs on CPU")
564
+ try:
565
+ from neucodec import NeuCodecOnnxDecoder
566
+ except ImportError as e:
567
+ raise ImportError(
568
+ "Failed to import ONNX decoder. "
569
+ "Ensure onnxruntime and neucodec >= 0.0.4 are installed."
570
+ ) from e
571
+ self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
572
+ self._is_onnx_codec = True
573
+ case _:
574
+ raise ValueError(f"Unsupported codec repository: {codec_repo}")
575
+
576
+ if enable_triton and not self._is_onnx_codec and codec_device != "cpu":
577
+ self._triton_enabled = _compile_codec_with_triton(self.codec)
578
+
579
+ def _warmup_model(self):
580
+ """Warmup inference pipeline to reduce first-token latency"""
581
+ print("🔥 Warming up model...")
582
+ try:
583
+ dummy_codes = list(range(10))
584
+ dummy_prompt = self._format_prompt(dummy_codes, "warmup", "test")
585
+ _ = self.backbone([dummy_prompt], gen_config=self.gen_config, do_preprocess=False)
586
+ print(" ✅ Warmup complete")
587
+ except Exception as e:
588
+ print(f" ⚠️ Warmup failed (non-critical): {e}")
589
+
590
+ def encode_reference(self, ref_audio_path: str | Path):
591
+ """Encode reference audio to codes"""
592
+ wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
593
+ wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
594
+ with torch.no_grad():
595
+ ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
596
+ return ref_codes
597
+
598
+ def get_cached_reference(self, voice_name: str, audio_path: str, ref_text: str = None):
599
+ """
600
+ Get or create cached reference codes.
601
+
602
+ Args:
603
+ voice_name: Unique identifier for this voice
604
+ audio_path: Path to reference audio
605
+ ref_text: Optional reference text (stored with codes)
606
+
607
+ Returns:
608
+ ref_codes: Encoded reference codes
609
+ """
610
+ cache_key = f"{voice_name}_{audio_path}"
611
+
612
+ if cache_key not in self._ref_cache:
613
+ ref_codes = self.encode_reference(audio_path)
614
+ self._ref_cache[cache_key] = {
615
+ 'codes': ref_codes,
616
+ 'ref_text': ref_text
617
+ }
618
+
619
+ return self._ref_cache[cache_key]['codes']
620
+
621
+ def add_speaker(self, user_id: int, audio_file: str, ref_text: str):
622
+ """
623
+ Add a speaker to the stored dictionary for easy access.
624
+
625
+ Args:
626
+ user_id: Unique user ID
627
+ audio_file: Reference audio file path
628
+ ref_text: Reference text
629
+
630
+ Returns:
631
+ user_id: The user ID for use in streaming
632
+ """
633
+ codes = self.encode_reference(audio_file)
634
+
635
+ if isinstance(codes, torch.Tensor):
636
+ codes = codes.cpu().numpy()
637
+ if isinstance(codes, np.ndarray):
638
+ codes = codes.flatten().tolist()
639
+
640
+ self.stored_dict[f"{user_id}"]['codes'] = codes
641
+ self.stored_dict[f"{user_id}"]['ref_text'] = ref_text
642
+
643
+ return user_id
644
+
645
+ def _decode(self, codes: str):
646
+ """Decode speech tokens to audio waveform"""
647
+ speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
648
+
649
+ if len(speech_ids) == 0:
650
+ raise ValueError(
651
+ "No valid speech tokens found in the output. "
652
+ "Lỗi này có thể do GPU của bạn không hỗ trợ định dạng bfloat16 (ví dụ: dòng T4, RTX 20-series) "
653
+ "khiến mô hình chạy không ổn định trên LMDeploy (Turbomind). Bạn hãy thử bỏ chọn 'LMDeploy' "
654
+ "trong Tùy chọn nâng cao hoặc chuyển sang dùng phiên bản GGUF Q4/Q8 để chạy ổn định hơn."
655
+ )
656
+
657
+ if self._is_onnx_codec:
658
+ codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
659
+ recon = self.codec.decode_code(codes)
660
+ else:
661
+ with torch.no_grad():
662
+ codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
663
+ self.codec.device
664
+ )
665
+ recon = self.codec.decode_code(codes).cpu().numpy()
666
+
667
+ return recon[0, 0, :]
668
+
669
+ def _decode_batch(self, codes_list: list[str], max_workers: int = None):
670
+ """
671
+ Decode multiple code strings in parallel.
672
+
673
+ Args:
674
+ codes_list: List of code strings to decode
675
+ max_workers: Number of parallel workers (auto-tuned if None)
676
+
677
+ Returns:
678
+ List of decoded audio arrays
679
+ """
680
+ # Auto-tune workers based on GPU memory and batch size
681
+ if max_workers is None:
682
+ if torch.cuda.is_available():
683
+ gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
684
+ # 1 worker per 4GB VRAM, max 4 workers
685
+ max_workers = min(max(1, int(gpu_mem_gb / 4)), 4)
686
+ else:
687
+ max_workers = 2
688
+
689
+ # For small batches, use sequential to avoid overhead
690
+ if len(codes_list) <= 2:
691
+ return [self._decode(codes) for codes in codes_list]
692
+
693
+ # Parallel decoding with controlled workers
694
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
695
+ futures = [executor.submit(self._decode, codes) for codes in codes_list]
696
+ results = [f.result() for f in futures]
697
+ return results
698
+
699
+ def _format_prompt(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
700
+ """Format prompt for LMDeploy"""
701
+ ref_text_phones = phonemize_with_dict(ref_text)
702
+ input_text_phones = phonemize_with_dict(input_text)
703
+
704
+ codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
705
+
706
+ prompt = (
707
+ f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text_phones} {input_text_phones}"
708
+ f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
709
+ )
710
+
711
+ return prompt
712
+
713
+ def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
714
+ """
715
+ Single inference.
716
+
717
+ Args:
718
+ text: Input text to synthesize
719
+ ref_codes: Encoded reference audio codes
720
+ ref_text: Reference text for reference audio
721
+
722
+ Returns:
723
+ Generated speech waveform as numpy array
724
+ """
725
+ if isinstance(ref_codes, torch.Tensor):
726
+ ref_codes = ref_codes.cpu().numpy()
727
+ if isinstance(ref_codes, np.ndarray):
728
+ ref_codes = ref_codes.flatten().tolist()
729
+
730
+ prompt = self._format_prompt(ref_codes, ref_text, text)
731
+
732
+ # Use LMDeploy pipeline for generation
733
+ responses = self.backbone([prompt], gen_config=self.gen_config, do_preprocess=False)
734
+ output_str = responses[0].text
735
+
736
+ # Decode to audio
737
+ wav = self._decode(output_str)
738
+
739
+ # Apply watermark if available
740
+ if self.watermarker:
741
+ wav = self.watermarker.apply_watermark(wav, sample_rate=self.sample_rate)
742
+
743
+ return wav
744
+
745
+ def infer_batch(self, texts: list[str], ref_codes: np.ndarray | torch.Tensor, ref_text: str, max_batch_size: int = None) -> list[np.ndarray]:
746
+ """
747
+ Batch inference for multiple texts.
748
+ """
749
+ if max_batch_size is None:
750
+ max_batch_size = self.max_batch_size
751
+
752
+ if not isinstance(texts, list):
753
+ texts = [texts]
754
+
755
+ if isinstance(ref_codes, torch.Tensor):
756
+ ref_codes = ref_codes.cpu().numpy()
757
+ if isinstance(ref_codes, np.ndarray):
758
+ ref_codes = ref_codes.flatten().tolist()
759
+
760
+ all_wavs = []
761
+
762
+ for i in range(0, len(texts), max_batch_size):
763
+ batch_texts = texts[i:i+max_batch_size]
764
+ prompts = [self._format_prompt(ref_codes, ref_text, text) for text in batch_texts]
765
+ responses = self.backbone(prompts, gen_config=self.gen_config, do_preprocess=False)
766
+ batch_codes = [response.text for response in responses]
767
+
768
+ if len(batch_codes) > 3:
769
+ batch_wavs = self._decode_batch(batch_codes)
770
+ else:
771
+ batch_wavs = [self._decode(codes) for codes in batch_codes]
772
+
773
+ # Apply watermark if available
774
+ if self.watermarker:
775
+ batch_wavs = [self.watermarker.apply_watermark(w, sample_rate=self.sample_rate) for w in batch_wavs]
776
+
777
+ all_wavs.extend(batch_wavs)
778
+
779
+ if i + max_batch_size < len(texts):
780
+ if torch.cuda.is_available():
781
+ torch.cuda.empty_cache()
782
+
783
+ return all_wavs
784
+
785
+ def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
786
+ """
787
+ Streaming inference with low latency.
788
+
789
+ Args:
790
+ text: Input text to synthesize
791
+ ref_codes: Encoded reference audio codes
792
+ ref_text: Reference text for reference audio
793
+
794
+ Yields:
795
+ Audio chunks as numpy arrays
796
+ """
797
+ if isinstance(ref_codes, torch.Tensor):
798
+ ref_codes = ref_codes.cpu().numpy()
799
+ if isinstance(ref_codes, np.ndarray):
800
+ ref_codes = ref_codes.flatten().tolist()
801
+
802
+ prompt = self._format_prompt(ref_codes, ref_text, text)
803
+
804
+ audio_cache = []
805
+ token_cache = [f"<|speech_{idx}|>" for idx in ref_codes]
806
+ n_decoded_samples = 0
807
+ n_decoded_tokens = len(ref_codes)
808
+
809
+ for response in self.backbone.stream_infer([prompt], gen_config=self.gen_config, do_preprocess=False):
810
+ output_str = response.text
811
+
812
+ # Extract new tokens
813
+ new_tokens = output_str[len("".join(token_cache[len(ref_codes):])):] if len(token_cache) > len(ref_codes) else output_str
814
+
815
+ if new_tokens:
816
+ token_cache.append(new_tokens)
817
+
818
+ # Check if we have enough tokens to decode a chunk
819
+ if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
820
+
821
+ # Decode chunk with context
822
+ tokens_start = max(
823
+ n_decoded_tokens - self.streaming_lookback - self.streaming_overlap_frames,
824
+ 0
825
+ )
826
+ tokens_end = (
827
+ n_decoded_tokens
828
+ + self.streaming_frames_per_chunk
829
+ + self.streaming_lookforward
830
+ + self.streaming_overlap_frames
831
+ )
832
+ sample_start = (n_decoded_tokens - tokens_start) * self.hop_length
833
+ sample_end = (
834
+ sample_start
835
+ + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
836
+ )
837
+
838
+ curr_codes = token_cache[tokens_start:tokens_end]
839
+ recon = self._decode("".join(curr_codes))
840
+ recon = recon[sample_start:sample_end]
841
+ audio_cache.append(recon)
842
+
843
+ # Overlap-add processing
844
+ processed_recon = _linear_overlap_add(
845
+ audio_cache, stride=self.streaming_stride_samples
846
+ )
847
+ new_samples_end = len(audio_cache) * self.streaming_stride_samples
848
+ processed_recon = processed_recon[n_decoded_samples:new_samples_end]
849
+ n_decoded_samples = new_samples_end
850
+ n_decoded_tokens += self.streaming_frames_per_chunk
851
+
852
+ yield processed_recon
853
+
854
+ # Final chunk
855
+ remaining_tokens = len(token_cache) - n_decoded_tokens
856
+ if remaining_tokens > 0:
857
+ tokens_start = max(
858
+ len(token_cache) - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
859
+ 0
860
+ )
861
+ sample_start = (
862
+ len(token_cache) - tokens_start - remaining_tokens - self.streaming_overlap_frames
863
+ ) * self.hop_length
864
+
865
+ curr_codes = token_cache[tokens_start:]
866
+ recon = self._decode("".join(curr_codes))
867
+ recon = recon[sample_start:]
868
+ audio_cache.append(recon)
869
+
870
+ processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
871
+ processed_recon = processed_recon[n_decoded_samples:]
872
+ yield processed_recon
873
+
874
+ def cleanup_memory(self):
875
+ """Clean up GPU memory"""
876
+ if torch.cuda.is_available():
877
+ torch.cuda.empty_cache()
878
+ gc.collect()
879
+ print("🧹 Memory cleaned up")
880
+
881
+ def get_optimization_stats(self) -> dict:
882
+ """
883
+ Get current optimization statistics.
884
+
885
+ Returns:
886
+ Dictionary with optimization info
887
+ """
888
+ return {
889
+ 'triton_enabled': self._triton_enabled,
890
+ 'max_batch_size': self.max_batch_size,
891
+ 'cached_references': len(self._ref_cache),
892
+ 'active_sessions': len(self.stored_dict),
893
+ 'kv_quant': self.gen_config.__dict__.get('quant_policy', 0),
894
+ 'prefix_caching': True, # Always enabled in our config
895
+ }