nhantrungsp committed
Commit ecb8409 · verified · 1 Parent(s): 96cf6e9

Upload 2 files

Files changed (2)
  1. vieneu_tts/__init__.py +4 -0
  2. vieneu_tts/vieneu_tts.py +385 -0
vieneu_tts/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .vieneu_tts import VieNeuTTS
+
+ __all__ = ["VieNeuTTS"]
+
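The new `__init__.py` simply re-exports the class, so downstream code can import it directly. A minimal sketch, assuming the repository root is on the Python path:

    from vieneu_tts import VieNeuTTS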
vieneu_tts/vieneu_tts.py ADDED
@@ -0,0 +1,385 @@
+ from pathlib import Path
+ from typing import Generator
+ import librosa
+ import numpy as np
+ import torch
+ from neucodec import NeuCodec, DistillNeuCodec
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from utils.phonemize_text import phonemize_text, phonemize_with_dict
+ import re
+
+ def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
+     # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
+     assert len(frames)
+     dtype = frames[0].dtype
+     shape = frames[0].shape[:-1]
+
+     total_size = 0
+     for i, frame in enumerate(frames):
+         frame_end = stride * i + frame.shape[-1]
+         total_size = max(total_size, frame_end)
+
+     sum_weight = np.zeros(total_size, dtype=dtype)
+     out = np.zeros((*shape, total_size), dtype=dtype)  # np.zeros takes the shape as a single tuple argument
+
+     offset: int = 0
+     for frame in frames:
+         frame_length = frame.shape[-1]
+         t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+         weight = np.abs(0.5 - (t - 0.5))
+
+         out[..., offset : offset + frame_length] += weight * frame
+         sum_weight[offset : offset + frame_length] += weight
+         offset += stride
+     assert sum_weight.min() > 0
+     return out / sum_weight
+
+ class VieNeuTTS:
+     def __init__(
+         self,
+         backbone_repo="pnnbao-ump/VieNeu-TTS",
+         backbone_device="cpu",
+         codec_repo="neuphonic/neucodec",
+         codec_device="cpu",
+     ):
+
+         # Constants
+         self.sample_rate = 24_000
+         self.max_context = 2048
+         self.hop_length = 480
+         self.streaming_overlap_frames = 1
+         self.streaming_frames_per_chunk = 25
+         self.streaming_lookforward = 5
+         self.streaming_lookback = 50
+         self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
+
+         # ggml & onnx flags
+         self._is_quantized_model = False
+         self._is_onnx_codec = False
+
+         # HF tokenizer
+         self.tokenizer = None
+
+         # Load models
+         self._load_backbone(backbone_repo, backbone_device)
+         self._load_codec(codec_repo, codec_device)
+
+     def _load_backbone(self, backbone_repo, backbone_device):
+         print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
+
+         if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
+             try:
+                 from llama_cpp import Llama
+             except ImportError as e:
+                 raise ImportError(
+                     "Failed to import `llama_cpp`. "
+                     "Please install it with:\n"
+                     " pip install llama-cpp-python"
+                 ) from e
+             self.backbone = Llama.from_pretrained(
+                 repo_id=backbone_repo,
+                 filename="*.gguf",
+                 verbose=False,
+                 n_gpu_layers=-1 if backbone_device == "gpu" else 0,
+                 n_ctx=self.max_context,
+                 mlock=True,
+                 flash_attn=True if backbone_device == "gpu" else False,
+             )
+             self._is_quantized_model = True
+
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
+             print(f" Loading model to device: {backbone_device}")
+
+             print(f" 📦 Loading with FP32 (stable mode)")
+             self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo)
+
+             print(f" Model loaded, moving to {backbone_device}...")
+             self.backbone = self.backbone.to(torch.device(backbone_device))
+             print(f" ✓ Backbone on device: {next(self.backbone.parameters()).device}")
+             print(f" ✓ Backbone dtype: {next(self.backbone.parameters()).dtype}")
+
+     def _load_codec(self, codec_repo, codec_device):
+         print(f"Loading codec from: {codec_repo} on {codec_device} ...")
+         match codec_repo:
+             case "neuphonic/neucodec":
+                 self.codec = NeuCodec.from_pretrained(codec_repo)
+
+                 # Keep codec in FP32 for compatibility with feature_extractor
+                 # Only backbone uses FP16
+                 print(f" 📦 Keeping codec in FP32 (compatibility)")
+
+                 self.codec.eval().to(codec_device)
+                 print(f" ✓ Codec on device: {next(self.codec.parameters()).device}")
+                 print(f" ✓ Codec dtype: {next(self.codec.parameters()).dtype}")
+             case "neuphonic/distill-neucodec":
+                 self.codec = DistillNeuCodec.from_pretrained(codec_repo)
+
+                 # Keep distill-codec in FP32 for compatibility
+                 print(f" 📦 Keeping distill-codec in FP32 (compatibility)")
+
+                 self.codec.eval().to(codec_device)
+                 print(f" ✓ Distill-Codec on device: {next(self.codec.parameters()).device}")
+                 print(f" ✓ Distill-Codec dtype: {next(self.codec.parameters()).dtype}")
+             case "neuphonic/neucodec-onnx-decoder":
+                 if codec_device != "cpu":
+                     raise ValueError("Onnx decoder only currently runs on CPU.")
+                 try:
+                     from neucodec import NeuCodecOnnxDecoder
+                 except ImportError as e:
+                     raise ImportError(
+                         "Failed to import the onnx decoder."
+                         " Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
+                     ) from e
+                 self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
+                 self._is_onnx_codec = True
+             case _:
+                 raise ValueError(f"Unsupported codec repository: {codec_repo}")
+
+     def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
+         """
+         Perform inference to generate speech from text using the TTS model and reference audio.
+
+         Args:
+             text (str): Input text to be converted to speech.
+             ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
+             ref_text (str): Transcript of the reference audio.
+         Returns:
+             np.ndarray: Generated speech waveform.
+         """
+
+         # Generate tokens
+         if self._is_quantized_model:
+             output_str = self._infer_ggml(ref_codes, ref_text, text)
+         else:
+             prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
+             output_str = self._infer_torch(prompt_ids)
+
+         # Decode
+         wav = self._decode(output_str)
+
+         return wav
+
+     def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+         """
+         Perform streaming inference to generate speech from text using the TTS model and reference audio.
+
+         Args:
+             text (str): Input text to be converted to speech.
+             ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
+             ref_text (str): Transcript of the reference audio.
+         Yields:
+             np.ndarray: Generated speech waveform.
+         """
+
+         if self._is_quantized_model:
+             return self._infer_stream_ggml(ref_codes, ref_text, text)
+         else:
+             raise NotImplementedError("Streaming is not implemented for the torch backend!")
+
+     def encode_reference(self, ref_audio_path: str | Path):
+         wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
+         wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
+
+         # NeuCodec expects CPU tensor for encode_code
+         wav_tensor_cpu = wav_tensor.cpu().float()
+
+         with torch.no_grad():
+             ref_codes = self.codec.encode_code(audio_or_path=wav_tensor_cpu).squeeze(0).squeeze(0)
+
+         # Ensure result is on CPU for caching
+         if ref_codes.device.type != 'cpu':
+             ref_codes = ref_codes.cpu()
+
+         return ref_codes
+
+     def _decode(self, codes: str):
+         """Decode speech tokens to audio waveform."""
+         # Extract speech token IDs using regex
+         speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
+
+         if len(speech_ids) == 0:
+             raise ValueError(
+                 "No valid speech tokens found in the output. "
+                 "The model may not have generated proper speech tokens."
+             )
+
+         # Onnx decode
+         if self._is_onnx_codec:
+             codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
+             recon = self.codec.decode_code(codes)
+         # Torch decode
+         else:
+             with torch.no_grad():
+                 codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
+                     self.codec.device
+                 )
+
+                 # Codec is kept in FP32, no need for autocast
+                 recon = self.codec.decode_code(codes).cpu().numpy()
+
+         return recon[0, 0, :]
+
+     def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
+         # Convert ref_codes to list if it's a tensor
+         if hasattr(ref_codes, 'cpu'):
+             ref_codes = ref_codes.cpu().numpy().tolist()
+         elif hasattr(ref_codes, 'tolist'):
+             ref_codes = ref_codes.tolist()
+
+         input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
+
+         speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
+         speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
+         text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
+         text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
+         text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
+
+         input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
+         chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
+         ids = self.tokenizer.encode(chat)
+
+         text_replace_idx = ids.index(text_replace)
+         ids = (
+             ids[:text_replace_idx]
+             + [text_prompt_start]
+             + input_ids
+             + [text_prompt_end]
+             + ids[text_replace_idx + 1 :]  # noqa
+         )
+
+         speech_replace_idx = ids.index(speech_replace)
+         codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
+         codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
+         ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
+
+         return ids
+
+     def _infer_torch(self, prompt_ids: list[int]) -> str:
+         prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
+         speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
+
+         with torch.no_grad():
+             output_tokens = self.backbone.generate(
+                 prompt_tensor,
+                 max_length=self.max_context,
+                 eos_token_id=speech_end_id,
+                 do_sample=True,
+                 temperature=1.0,
+                 top_k=50,
+                 use_cache=True,
+                 min_new_tokens=50,
+             )
+
+         input_length = prompt_tensor.shape[-1]
+         output_str = self.tokenizer.decode(
+             output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
+         )
+         return output_str
+
+     def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
+         ref_text = phonemize_with_dict(ref_text)
+         input_text = phonemize_with_dict(input_text)
+
+         codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+         prompt = (
+             f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+             f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+         )
+         output = self.backbone(
+             prompt,
+             max_tokens=self.max_context,
+             temperature=1.0,
+             top_k=50,
+             stop=["<|SPEECH_GENERATION_END|>"],
+         )
+         output_str = output["choices"][0]["text"]
+         return output_str
+
+     def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
+         ref_text = phonemize_with_dict(ref_text)
+         input_text = phonemize_with_dict(input_text)
+
+         codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+         prompt = (
+             f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+             f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+         )
+
+         audio_cache: list[np.ndarray] = []
+         token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
+         n_decoded_samples: int = 0
+         n_decoded_tokens: int = len(ref_codes)
+
+         for item in self.backbone(
+             prompt,
+             max_tokens=self.max_context,
+             temperature=0.2,
+             top_k=50,
+             stop=["<|SPEECH_GENERATION_END|>"],
+             stream=True
+         ):
+             output_str = item["choices"][0]["text"]
+             token_cache.append(output_str)
+
+             if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
+
+                 # decode chunk
+                 tokens_start = max(
+                     n_decoded_tokens
+                     - self.streaming_lookback
+                     - self.streaming_overlap_frames,
+                     0
+                 )
+                 tokens_end = (
+                     n_decoded_tokens
+                     + self.streaming_frames_per_chunk
+                     + self.streaming_lookforward
+                     + self.streaming_overlap_frames
+                 )
+                 sample_start = (
+                     n_decoded_tokens - tokens_start
+                 ) * self.hop_length
+                 sample_end = (
+                     sample_start
+                     + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
+                 )
+                 curr_codes = token_cache[tokens_start:tokens_end]
+                 recon = self._decode("".join(curr_codes))
+                 recon = recon[sample_start:sample_end]
+                 audio_cache.append(recon)
+
+                 # postprocess
+                 processed_recon = _linear_overlap_add(
+                     audio_cache, stride=self.streaming_stride_samples
+                 )
+                 new_samples_end = len(audio_cache) * self.streaming_stride_samples
+                 processed_recon = processed_recon[
+                     n_decoded_samples:new_samples_end
+                 ]
+                 n_decoded_samples = new_samples_end
+                 n_decoded_tokens += self.streaming_frames_per_chunk
+                 yield processed_recon
+
+         # final decoding handled separately as non-constant chunk size
+         remaining_tokens = len(token_cache) - n_decoded_tokens
+         if len(token_cache) > n_decoded_tokens:
+             tokens_start = max(
+                 len(token_cache)
+                 - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
+                 0
+             )
+             sample_start = (
+                 len(token_cache)
+                 - tokens_start
+                 - remaining_tokens
+                 - self.streaming_overlap_frames
+             ) * self.hop_length
+             curr_codes = token_cache[tokens_start:]
+             recon = self._decode("".join(curr_codes))
+             recon = recon[sample_start:]
+             audio_cache.append(recon)
+
+         processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
+         processed_recon = processed_recon[n_decoded_samples:]
+         yield processed_recon
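For reference, a minimal end-to-end sketch of how the uploaded class is meant to be driven, based only on the methods above. The reference clip path, its transcript, the example text, and the use of `soundfile` for writing the result are illustrative assumptions, not part of this commit:

    import soundfile as sf
    from vieneu_tts import VieNeuTTS

    # Defaults: pnnbao-ump/VieNeu-TTS backbone and neuphonic/neucodec codec, both on CPU.
    tts = VieNeuTTS(backbone_device="cpu", codec_device="cpu")

    # Hypothetical reference clip (loaded at 16 kHz mono by encode_reference) and its transcript.
    ref_codes = tts.encode_reference("ref.wav")
    wav = tts.infer("Xin chào, đây là giọng đọc thử nghiệm.", ref_codes, ref_text="transcript of ref.wav")

    # The codec reconstructs audio at tts.sample_rate (24 kHz).
    sf.write("output.wav", wav, tts.sample_rate)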