pnnbao-ump committed
Commit c3a6523 · verified · 1 parent: f202770

Delete vieneu_tts

vieneu_tts/__init__.py DELETED
@@ -1,4 +0,0 @@
-from .vieneu_tts import VieNeuTTS, FastVieNeuTTS
-
-__all__ = ["VieNeuTTS", "FastVieNeuTTS"]
-
 
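Note for reviewers: the deleted __init__.py re-exported both engines at the package root, so imports of the following form (a sketch, assuming the repository root is on sys.path) now raise ModuleNotFoundError:

from vieneu_tts import VieNeuTTS, FastVieNeuTTS
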
vieneu_tts/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (241 Bytes)
 
vieneu_tts/__pycache__/vieneu_tts.cpython-312.pyc DELETED
Binary file (39 kB)
 
vieneu_tts/__pycache__/vieneu_tts_gpu.cpython-312.pyc DELETED
Binary file (24.1 kB)
 
vieneu_tts/vieneu_tts.py DELETED
@@ -1,869 +0,0 @@
-from pathlib import Path
-from typing import Generator
-import librosa
-import numpy as np
-import torch
-from neucodec import NeuCodec, DistillNeuCodec
-from utils.phonemize_text import phonemize_with_dict
-from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
-import re
-import gc
-
-# ============================================================================
-# Shared Utilities
-# ============================================================================
-
-def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
-    """Linear overlap-add for smooth audio concatenation"""
-    assert len(frames)
-    dtype = frames[0].dtype
-    shape = frames[0].shape[:-1]
-
-    total_size = 0
-    for i, frame in enumerate(frames):
-        frame_end = stride * i + frame.shape[-1]
-        total_size = max(total_size, frame_end)
-
-    sum_weight = np.zeros(total_size, dtype=dtype)
-    out = np.zeros((*shape, total_size), dtype=dtype)
-
-    offset: int = 0
-    for frame in frames:
-        frame_length = frame.shape[-1]
-        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
-        weight = np.abs(0.5 - (t - 0.5))
-
-        out[..., offset : offset + frame_length] += weight * frame
-        sum_weight[offset : offset + frame_length] += weight
-        offset += stride
-    assert sum_weight.min() > 0
-    return out / sum_weight
-
-
-def _compile_codec_with_triton(codec):
-    """Compile codec with Triton for faster decoding (Windows/Linux compatible)"""
-    try:
-        import triton
-
-        if hasattr(codec, 'dec') and hasattr(codec.dec, 'resblocks'):
-            if len(codec.dec.resblocks) > 2:
-                codec.dec.resblocks[2].forward = torch.compile(
-                    codec.dec.resblocks[2].forward,
-                    mode="reduce-overhead",
-                    dynamic=True
-                )
-                print(" ✅ Triton compilation enabled for codec")
-                return True
-
-    except ImportError:
-        print(" ⚠️ Triton not found. Install for faster speed:")
-        print(" • Linux: pip install triton")
-        print(" • Windows: pip install triton-windows")
-        print(" (Optional but recommended)")
-        return False
-
-
-# ============================================================================
-# VieNeuTTS - Standard implementation (CPU/GPU compatible)
-# Supports: PyTorch Transformers, GGUF/GGML quantized models
-# ============================================================================
-
-class VieNeuTTS:
-    """
-    Standard VieNeu-TTS implementation.
-
-    Supports:
-    - PyTorch + Transformers backend (CPU/GPU)
-    - GGUF quantized models via llama-cpp-python (CPU optimized)
-
-    Use this for:
-    - CPU-only environments
-    - Standard PyTorch workflows
-    - GGUF quantized models
-    """
-
-    def __init__(
-        self,
-        backbone_repo="pnnbao-ump/VieNeu-TTS",
-        backbone_device="cpu",
-        codec_repo="neuphonic/neucodec",
-        codec_device="cpu",
-    ):
-        """
-        Initialize VieNeu-TTS.
-
-        Args:
-            backbone_repo: Model repository or path to GGUF file
-            backbone_device: Device for backbone ('cpu', 'cuda', 'gpu')
-            codec_repo: Codec repository
-            codec_device: Device for codec
-        """
-
-        # Constants
-        self.sample_rate = 24_000
-        self.max_context = 2048
-        self.hop_length = 480
-        self.streaming_overlap_frames = 1
-        self.streaming_frames_per_chunk = 25
-        self.streaming_lookforward = 5
-        self.streaming_lookback = 50
-        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
-
-        # Flags
-        self._is_quantized_model = False
-        self._is_onnx_codec = False
-
-        # HF tokenizer
-        self.tokenizer = None
-
-        # Load models
-        self._load_backbone(backbone_repo, backbone_device)
-        self._load_codec(codec_repo, codec_device)
-
-    def _load_backbone(self, backbone_repo, backbone_device):
-        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
-
-        if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
-            try:
-                from llama_cpp import Llama
-            except ImportError as e:
-                raise ImportError(
-                    "Failed to import `llama_cpp`. "
-                    "Please install it with:\n"
-                    " pip install llama-cpp-python"
-                ) from e
-            self.backbone = Llama.from_pretrained(
-                repo_id=backbone_repo,
-                filename="*.gguf",
-                verbose=False,
-                n_gpu_layers=-1 if backbone_device == "gpu" else 0,
-                n_ctx=self.max_context,
-                mlock=True,
-                flash_attn=True if backbone_device == "gpu" else False,
-            )
-            self._is_quantized_model = True
-
-        else:
-            from transformers import AutoTokenizer, AutoModelForCausalLM
-            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
-            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
-                torch.device(backbone_device)
-            )
-
-    def _load_codec(self, codec_repo, codec_device):
-        print(f"Loading codec from: {codec_repo} on {codec_device} ...")
-        match codec_repo:
-            case "neuphonic/neucodec":
-                self.codec = NeuCodec.from_pretrained(codec_repo)
-                self.codec.eval().to(codec_device)
-            case "neuphonic/distill-neucodec":
-                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
-                self.codec.eval().to(codec_device)
-            case "neuphonic/neucodec-onnx-decoder":
-                if codec_device != "cpu":
-                    raise ValueError("Onnx decoder only currently runs on CPU.")
-                try:
-                    from neucodec import NeuCodecOnnxDecoder
-                except ImportError as e:
-                    raise ImportError(
-                        "Failed to import the onnx decoder. "
-                        "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
-                    ) from e
-                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
-                self._is_onnx_codec = True
-            case _:
-                raise ValueError(f"Unsupported codec repository: {codec_repo}")
-
-    def encode_reference(self, ref_audio_path: str | Path):
-        """Encode reference audio to codes"""
-        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
-        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
-        with torch.no_grad():
-            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
-        return ref_codes
-
-    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
-        """
-        Perform inference to generate speech from text using the TTS model and reference audio.
-
-        Args:
-            text (str): Input text to be converted to speech.
-            ref_codes (np.ndarray | torch.Tensor): Encoded reference.
-            ref_text (str): Reference text for reference audio.
-        Returns:
-            np.ndarray: Generated speech waveform.
-        """
-
-        # Generate tokens
-        if self._is_quantized_model:
-            output_str = self._infer_ggml(ref_codes, ref_text, text)
-        else:
-            prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
-            output_str = self._infer_torch(prompt_ids)
-
-        # Decode
-        wav = self._decode(output_str)
-
-        return wav
-
-    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
-        """
-        Perform streaming inference to generate speech from text using the TTS model and reference audio.
-
-        Args:
-            text (str): Input text to be converted to speech.
-            ref_codes (np.ndarray | torch.Tensor): Encoded reference.
-            ref_text (str): Reference text for reference audio.
-        Yields:
-            np.ndarray: Generated speech waveform.
-        """
-
-        if self._is_quantized_model:
-            return self._infer_stream_ggml(ref_codes, ref_text, text)
-        else:
-            raise NotImplementedError("Streaming is not implemented for the torch backend!")
-
-    def _decode(self, codes: str):
-        """Decode speech tokens to audio waveform."""
-        # Extract speech token IDs using regex
-        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
-
-        if len(speech_ids) == 0:
-            raise ValueError(
-                "No valid speech tokens found in the output. "
-                "The model may not have generated proper speech tokens."
-            )
-
-        # Onnx decode
-        if self._is_onnx_codec:
-            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
-            recon = self.codec.decode_code(codes)
-        # Torch decode
-        else:
-            with torch.no_grad():
-                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
-                    self.codec.device
-                )
-                recon = self.codec.decode_code(codes).cpu().numpy()
-
-        return recon[0, 0, :]
-
-    def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
-        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
-
-        speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
-        speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
-        text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
-        text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
-        text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
-
-        input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
-        chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
-        ids = self.tokenizer.encode(chat)
-
-        text_replace_idx = ids.index(text_replace)
-        ids = (
-            ids[:text_replace_idx]
-            + [text_prompt_start]
-            + input_ids
-            + [text_prompt_end]
-            + ids[text_replace_idx + 1 :]  # noqa
-        )
-
-        speech_replace_idx = ids.index(speech_replace)
-        codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
-        codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
-        ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
-
-        return ids
-
-    def _infer_torch(self, prompt_ids: list[int]) -> str:
-        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
-        speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
-        with torch.no_grad():
-            output_tokens = self.backbone.generate(
-                prompt_tensor,
-                max_length=self.max_context,
-                eos_token_id=speech_end_id,
-                do_sample=True,
-                temperature=1.0,
-                top_k=50,
-                use_cache=True,
-                min_new_tokens=50,
-            )
-        input_length = prompt_tensor.shape[-1]
-        output_str = self.tokenizer.decode(
-            output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
-        )
-        return output_str
-
-    def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
-        ref_text = phonemize_with_dict(ref_text)
-        input_text = phonemize_with_dict(input_text)
-
-        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
-        prompt = (
-            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
-            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
-        )
-        output = self.backbone(
-            prompt,
-            max_tokens=self.max_context,
-            temperature=1.0,
-            top_k=50,
-            stop=["<|SPEECH_GENERATION_END|>"],
-        )
-        output_str = output["choices"][0]["text"]
-        return output_str
-
-    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
-        ref_text = phonemize_with_dict(ref_text)
-        input_text = phonemize_with_dict(input_text)
-
-        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
-        prompt = (
-            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
-            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
-        )
-
-        audio_cache: list[np.ndarray] = []
-        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
-        n_decoded_samples: int = 0
-        n_decoded_tokens: int = len(ref_codes)
-
-        for item in self.backbone(
-            prompt,
-            max_tokens=self.max_context,
-            temperature=1.0,
-            top_k=50,
-            stop=["<|SPEECH_GENERATION_END|>"],
-            stream=True
-        ):
-            output_str = item["choices"][0]["text"]
-            token_cache.append(output_str)
-
-            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
-
-                # decode chunk
-                tokens_start = max(
-                    n_decoded_tokens
-                    - self.streaming_lookback
-                    - self.streaming_overlap_frames,
-                    0
-                )
-                tokens_end = (
-                    n_decoded_tokens
-                    + self.streaming_frames_per_chunk
-                    + self.streaming_lookforward
-                    + self.streaming_overlap_frames
-                )
-                sample_start = (
-                    n_decoded_tokens - tokens_start
-                ) * self.hop_length
-                sample_end = (
-                    sample_start
-                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
-                )
-                curr_codes = token_cache[tokens_start:tokens_end]
-                recon = self._decode("".join(curr_codes))
-                recon = recon[sample_start:sample_end]
-                audio_cache.append(recon)
-
-                # postprocess
-                processed_recon = _linear_overlap_add(
-                    audio_cache, stride=self.streaming_stride_samples
-                )
-                new_samples_end = len(audio_cache) * self.streaming_stride_samples
-                processed_recon = processed_recon[
-                    n_decoded_samples:new_samples_end
-                ]
-                n_decoded_samples = new_samples_end
-                n_decoded_tokens += self.streaming_frames_per_chunk
-                yield processed_recon
-
-        # final decoding handled separately as non-constant chunk size
-        remaining_tokens = len(token_cache) - n_decoded_tokens
-        if len(token_cache) > n_decoded_tokens:
-            tokens_start = max(
-                len(token_cache)
-                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
-                0
-            )
-            sample_start = (
-                len(token_cache)
-                - tokens_start
-                - remaining_tokens
-                - self.streaming_overlap_frames
-            ) * self.hop_length
-            curr_codes = token_cache[tokens_start:]
-            recon = self._decode("".join(curr_codes))
-            recon = recon[sample_start:]
-            audio_cache.append(recon)
-
-            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
-            processed_recon = processed_recon[n_decoded_samples:]
-            yield processed_recon
-
-
-# ============================================================================
-# FastVieNeuTTS - GPU-optimized implementation
-# Requires: LMDeploy with CUDA
-# ============================================================================
-
-class FastVieNeuTTS:
-    """
-    GPU-optimized VieNeu-TTS using LMDeploy TurbomindEngine.
-    """
-
-    def __init__(
-        self,
-        backbone_repo="pnnbao-ump/VieNeu-TTS",
-        backbone_device="cuda",
-        codec_repo="neuphonic/neucodec",
-        codec_device="cuda",
-        memory_util=0.3,
-        tp=1,
-        enable_prefix_caching=True,
-        quant_policy=8,
-        enable_triton=True,
-        max_batch_size=8,
-    ):
-        """
-        Initialize FastVieNeuTTS with LMDeploy backend and optimizations.
-
-        Args:
-            backbone_repo: Model repository
-            backbone_device: Device for backbone (must be CUDA)
-            codec_repo: Codec repository
-            codec_device: Device for codec
-            memory_util: GPU memory utilization (0.0-1.0)
-            tp: Tensor parallel size for multi-GPU
-            enable_prefix_caching: Enable prefix caching for faster batch processing
-            quant_policy: KV cache quantization (0=off, 8=int8, 4=int4)
-            enable_triton: Enable Triton compilation for codec
-            max_batch_size: Maximum batch size for inference (prevent GPU overload)
-        """
-
-        if backbone_device != "cuda" and not backbone_device.startswith("cuda:"):
-            raise ValueError("LMDeploy backend requires CUDA device")
-
-        # Constants
-        self.sample_rate = 24_000
-        self.max_context = 2048
-        self.hop_length = 480
-        self.streaming_overlap_frames = 1
-        self.streaming_frames_per_chunk = 50
-        self.streaming_lookforward = 5
-        self.streaming_lookback = 50
-        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
-
-        self.max_batch_size = max_batch_size
-
-        self._ref_cache = {}
-
-        self.stored_dict = defaultdict(dict)
-
-        # Flags
-        self._is_onnx_codec = False
-        self._triton_enabled = False
-
-        # Load models
-        self._load_backbone_lmdeploy(backbone_repo, memory_util, tp, enable_prefix_caching, quant_policy)
-        self._load_codec(codec_repo, codec_device, enable_triton)
-
-        self._warmup_model()
-
-        print("✅ FastVieNeuTTS with optimizations loaded successfully!")
-        print(f" Max batch size: {self.max_batch_size} (adjustable to prevent GPU overload)")
-
-    def _load_backbone_lmdeploy(self, repo, memory_util, tp, enable_prefix_caching, quant_policy):
-        """Load backbone using LMDeploy's TurbomindEngine"""
-        print(f"Loading backbone with LMDeploy from: {repo}")
-
-        try:
-            from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
-        except ImportError as e:
-            raise ImportError(
-                "Failed to import `lmdeploy`. "
-                "Please install it with: pip install lmdeploy"
-            ) from e
-
-        backend_config = TurbomindEngineConfig(
-            cache_max_entry_count=memory_util,
-            tp=tp,
-            enable_prefix_caching=enable_prefix_caching,
-            dtype='bfloat16',
-            quant_policy=quant_policy
-        )
-
-        self.backbone = pipeline(repo, backend_config=backend_config)
-
-        self.gen_config = GenerationConfig(
-            top_p=0.95,
-            top_k=50,
-            temperature=1.0,
-            max_new_tokens=1024,
-            repetition_penalty=1.0,
-            do_sample=True,
-            min_new_tokens=40,
-            min_p=0.1,
-        )
-
-        # Record settings so get_optimization_stats() reports the values actually used
-        self._quant_policy = quant_policy
-        self._prefix_caching = enable_prefix_caching
-
-        print(" LMDeploy TurbomindEngine initialized")
-        print(f" - Memory util: {memory_util}")
-        print(f" - Tensor Parallel: {tp}")
-        print(f" - Prefix caching: {enable_prefix_caching}")
-        print(f" - KV quant: {quant_policy} ({'Enabled' if quant_policy > 0 else 'Disabled'})")
-
-    def _load_codec(self, codec_repo, codec_device, enable_triton):
-        """Load codec with optional Triton compilation"""
-        print(f"Loading codec from: {codec_repo} on {codec_device}")
-
-        match codec_repo:
-            case "neuphonic/neucodec":
-                self.codec = NeuCodec.from_pretrained(codec_repo)
-                self.codec.eval().to(codec_device)
-            case "neuphonic/distill-neucodec":
-                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
-                self.codec.eval().to(codec_device)
-            case "neuphonic/neucodec-onnx-decoder":
-                if codec_device != "cpu":
-                    raise ValueError("ONNX decoder only runs on CPU")
-                try:
-                    from neucodec import NeuCodecOnnxDecoder
-                except ImportError as e:
-                    raise ImportError(
-                        "Failed to import ONNX decoder. "
-                        "Ensure onnxruntime and neucodec >= 0.0.4 are installed."
-                    ) from e
-                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
-                self._is_onnx_codec = True
-            case _:
-                raise ValueError(f"Unsupported codec repository: {codec_repo}")
-
-        if enable_triton and not self._is_onnx_codec and codec_device != "cpu":
-            self._triton_enabled = _compile_codec_with_triton(self.codec)
-
-    def _warmup_model(self):
-        """Warmup inference pipeline to reduce first-token latency"""
-        print("🔥 Warming up model...")
-        try:
-            dummy_codes = list(range(10))
-            dummy_prompt = self._format_prompt(dummy_codes, "warmup", "test")
-            _ = self.backbone([dummy_prompt], gen_config=self.gen_config, do_preprocess=False)
-            print(" ✅ Warmup complete")
-        except Exception as e:
-            print(f" ⚠️ Warmup failed (non-critical): {e}")
-
-    def encode_reference(self, ref_audio_path: str | Path):
-        """Encode reference audio to codes"""
-        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
-        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
-        with torch.no_grad():
-            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
-        return ref_codes
-
-    def get_cached_reference(self, voice_name: str, audio_path: str, ref_text: str = None):
-        """
-        Get or create cached reference codes.
-
-        Args:
-            voice_name: Unique identifier for this voice
-            audio_path: Path to reference audio
-            ref_text: Optional reference text (stored with codes)
-
-        Returns:
-            ref_codes: Encoded reference codes
-        """
-        cache_key = f"{voice_name}_{audio_path}"
-
-        if cache_key not in self._ref_cache:
-            ref_codes = self.encode_reference(audio_path)
-            self._ref_cache[cache_key] = {
-                'codes': ref_codes,
-                'ref_text': ref_text
-            }
-
-        return self._ref_cache[cache_key]['codes']
-
-    def add_speaker(self, user_id: int, audio_file: str, ref_text: str):
-        """
-        Add a speaker to the stored dictionary for easy access.
-
-        Args:
-            user_id: Unique user ID
-            audio_file: Reference audio file path
-            ref_text: Reference text
-
-        Returns:
-            user_id: The user ID for use in streaming
-        """
-        codes = self.encode_reference(audio_file)
-
-        if isinstance(codes, torch.Tensor):
-            codes = codes.cpu().numpy()
-        if isinstance(codes, np.ndarray):
-            codes = codes.flatten().tolist()
-
-        self.stored_dict[f"{user_id}"]['codes'] = codes
-        self.stored_dict[f"{user_id}"]['ref_text'] = ref_text
-
-        return user_id
-
-    def _decode(self, codes: str):
-        """Decode speech tokens to audio waveform"""
-        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
-
-        if len(speech_ids) == 0:
-            raise ValueError("No valid speech tokens found in output")
-
-        if self._is_onnx_codec:
-            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
-            recon = self.codec.decode_code(codes)
-        else:
-            with torch.no_grad():
-                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
-                    self.codec.device
-                )
-                recon = self.codec.decode_code(codes).cpu().numpy()
-
-        return recon[0, 0, :]
-
-    def _decode_batch(self, codes_list: list[str], max_workers: int = None):
-        """
-        Decode multiple code strings in parallel.
-
-        Args:
-            codes_list: List of code strings to decode
-            max_workers: Number of parallel workers (auto-tuned if None)
-
-        Returns:
-            List of decoded audio arrays
-        """
-        # Auto-tune workers based on GPU memory and batch size
-        if max_workers is None:
-            if torch.cuda.is_available():
-                gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
-                # 1 worker per 4GB VRAM, max 4 workers
-                max_workers = min(max(1, int(gpu_mem_gb / 4)), 4)
-            else:
-                max_workers = 2
-
-        # For small batches, use sequential to avoid overhead
-        if len(codes_list) <= 2:
-            return [self._decode(codes) for codes in codes_list]
-
-        # Parallel decoding with controlled workers
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            futures = [executor.submit(self._decode, codes) for codes in codes_list]
-            results = [f.result() for f in futures]
-        return results
-
-    def _format_prompt(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
-        """Format prompt for LMDeploy"""
-        ref_text_phones = phonemize_with_dict(ref_text)
-        input_text_phones = phonemize_with_dict(input_text)
-
-        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
-
-        prompt = (
-            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text_phones} {input_text_phones}"
-            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
-        )
-
-        return prompt
-
-    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
-        """
-        Single inference.
-
-        Args:
-            text: Input text to synthesize
-            ref_codes: Encoded reference audio codes
-            ref_text: Reference text for reference audio
-
-        Returns:
-            Generated speech waveform as numpy array
-        """
-        if isinstance(ref_codes, torch.Tensor):
-            ref_codes = ref_codes.cpu().numpy()
-        if isinstance(ref_codes, np.ndarray):
-            ref_codes = ref_codes.flatten().tolist()
-
-        prompt = self._format_prompt(ref_codes, ref_text, text)
-
-        # Use LMDeploy pipeline for generation
-        responses = self.backbone([prompt], gen_config=self.gen_config, do_preprocess=False)
-        output_str = responses[0].text
-
-        # Decode to audio
-        wav = self._decode(output_str)
-
-        return wav
-
-    def infer_batch(self, texts: list[str], ref_codes: np.ndarray | torch.Tensor, ref_text: str, max_batch_size: int = None) -> list[np.ndarray]:
-        """
-        Batch inference for multiple texts.
-
-        Args:
-            texts: List of input texts to synthesize
-            ref_codes: Encoded reference audio codes
-            ref_text: Reference text for reference audio
-            max_batch_size: Maximum chunks to process at once (prevent GPU overload)
-
-        Returns:
-            List of generated speech waveforms
-        """
-        if max_batch_size is None:
-            max_batch_size = self.max_batch_size
-
-        if not isinstance(texts, list):
-            texts = [texts]
-
-        if isinstance(ref_codes, torch.Tensor):
-            ref_codes = ref_codes.cpu().numpy()
-        if isinstance(ref_codes, np.ndarray):
-            ref_codes = ref_codes.flatten().tolist()
-
-        all_wavs = []
-
-        # Process in smaller batches to avoid GPU OOM
-        for i in range(0, len(texts), max_batch_size):
-            batch_texts = texts[i:i+max_batch_size]
-
-            # Format prompts for this batch
-            prompts = [self._format_prompt(ref_codes, ref_text, text) for text in batch_texts]
-
-            # Batch generation with LMDeploy
-            responses = self.backbone(prompts, gen_config=self.gen_config, do_preprocess=False)
-
-            # Decode outputs (with smart parallelization)
-            batch_codes = [response.text for response in responses]
-
-            # Auto-tune parallel workers based on batch size
-            if len(batch_codes) > 3:
-                batch_wavs = self._decode_batch(batch_codes)
-            else:
-                # Sequential for small batches (less overhead)
-                batch_wavs = [self._decode(codes) for codes in batch_codes]
-
-            all_wavs.extend(batch_wavs)
-
-            # Clean up memory between batches
-            if i + max_batch_size < len(texts):
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-
-        return all_wavs
-
-    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
-        """
-        Streaming inference with low latency.
-
-        Args:
-            text: Input text to synthesize
-            ref_codes: Encoded reference audio codes
-            ref_text: Reference text for reference audio
-
-        Yields:
-            Audio chunks as numpy arrays
-        """
-        if isinstance(ref_codes, torch.Tensor):
-            ref_codes = ref_codes.cpu().numpy()
-        if isinstance(ref_codes, np.ndarray):
-            ref_codes = ref_codes.flatten().tolist()
-
-        prompt = self._format_prompt(ref_codes, ref_text, text)
-
-        audio_cache = []
-        token_cache = [f"<|speech_{idx}|>" for idx in ref_codes]
-        n_decoded_samples = 0
-        n_decoded_tokens = len(ref_codes)
-
-        for response in self.backbone.stream_infer([prompt], gen_config=self.gen_config, do_preprocess=False):
-            output_str = response.text
-
-            # Extract new tokens
-            new_tokens = output_str[len("".join(token_cache[len(ref_codes):])):] if len(token_cache) > len(ref_codes) else output_str
-
-            if new_tokens:
-                token_cache.append(new_tokens)
-
-            # Check if we have enough tokens to decode a chunk
-            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
-
-                # Decode chunk with context
-                tokens_start = max(
-                    n_decoded_tokens - self.streaming_lookback - self.streaming_overlap_frames,
-                    0
-                )
-                tokens_end = (
-                    n_decoded_tokens
-                    + self.streaming_frames_per_chunk
-                    + self.streaming_lookforward
-                    + self.streaming_overlap_frames
-                )
-                sample_start = (n_decoded_tokens - tokens_start) * self.hop_length
-                sample_end = (
-                    sample_start
-                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
-                )
-
-                curr_codes = token_cache[tokens_start:tokens_end]
-                recon = self._decode("".join(curr_codes))
-                recon = recon[sample_start:sample_end]
-                audio_cache.append(recon)
-
-                # Overlap-add processing
-                processed_recon = _linear_overlap_add(
-                    audio_cache, stride=self.streaming_stride_samples
-                )
-                new_samples_end = len(audio_cache) * self.streaming_stride_samples
-                processed_recon = processed_recon[n_decoded_samples:new_samples_end]
-                n_decoded_samples = new_samples_end
-                n_decoded_tokens += self.streaming_frames_per_chunk
-
-                yield processed_recon
-
-        # Final chunk
-        remaining_tokens = len(token_cache) - n_decoded_tokens
-        if remaining_tokens > 0:
-            tokens_start = max(
-                len(token_cache) - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
-                0
-            )
-            sample_start = (
-                len(token_cache) - tokens_start - remaining_tokens - self.streaming_overlap_frames
-            ) * self.hop_length
-
-            curr_codes = token_cache[tokens_start:]
-            recon = self._decode("".join(curr_codes))
-            recon = recon[sample_start:]
-            audio_cache.append(recon)
-
-            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
-            processed_recon = processed_recon[n_decoded_samples:]
-            yield processed_recon
-
-    def cleanup_memory(self):
-        """Clean up GPU memory"""
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
-        print("🧹 Memory cleaned up")
-
-    def get_optimization_stats(self) -> dict:
-        """
-        Get current optimization statistics.
-
-        Returns:
-            Dictionary with optimization info
-        """
-        return {
-            'triton_enabled': self._triton_enabled,
-            'cached_references': len(self._ref_cache),
-            'active_sessions': len(self.stored_dict),
-            'kv_quant': self._quant_policy,
-            'prefix_caching': self._prefix_caching,
-        }
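
For reference, a minimal usage sketch of the module removed by this commit, reconstructed from its public API. The audio paths, reference transcript, and the soundfile dependency below are illustrative assumptions, not part of the original code:

import soundfile as sf
from vieneu_tts import VieNeuTTS  # importable only before this commit

tts = VieNeuTTS(backbone_repo="pnnbao-ump/VieNeu-TTS", backbone_device="cpu")
ref_codes = tts.encode_reference("reference.wav")  # hypothetical reference clip, loaded at 16 kHz
wav = tts.infer(
    "Xin chào!",                              # text to synthesize
    ref_codes,
    ref_text="transcript of reference.wav",   # hypothetical transcript of the clip
)
sf.write("output.wav", wav, tts.sample_rate)  # the codec outputs 24 kHz mono

FastVieNeuTTS exposed the same encode_reference/infer pair plus infer_batch and a streaming infer_stream generator for CUDA deployments via LMDeploy.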