pnnbao-ump commited on
Commit
5d4607a
·
verified ·
1 Parent(s): 4754fab

Delete vieneu_tts.py

Browse files
Files changed (1) hide show
  1. vieneu_tts.py +0 -347
vieneu_tts.py DELETED
@@ -1,347 +0,0 @@
1
- from pathlib import Path
2
- from typing import Generator
3
- import librosa
4
- import numpy as np
5
- import torch
6
- from neucodec import NeuCodec, DistillNeuCodec
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
- from utils.phonemize_text import phonemize_text, phonemize_with_dict
9
- import re
10
-
11
- def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
12
- # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
13
- assert len(frames)
14
- dtype = frames[0].dtype
15
- shape = frames[0].shape[:-1]
16
-
17
- total_size = 0
18
- for i, frame in enumerate(frames):
19
- frame_end = stride * i + frame.shape[-1]
20
- total_size = max(total_size, frame_end)
21
-
22
- sum_weight = np.zeros(total_size, dtype=dtype)
23
- out = np.zeros(*shape, total_size, dtype=dtype)
24
-
25
- offset: int = 0
26
- for frame in frames:
27
- frame_length = frame.shape[-1]
28
- t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
29
- weight = np.abs(0.5 - (t - 0.5))
30
-
31
- out[..., offset : offset + frame_length] += weight * frame
32
- sum_weight[offset : offset + frame_length] += weight
33
- offset += stride
34
- assert sum_weight.min() > 0
35
- return out / sum_weight
36
-
37
- class VieNeuTTS:
38
- def __init__(
39
- self,
40
- backbone_repo="pnnbao-ump/VieNeu-TTS",
41
- backbone_device="cpu",
42
- codec_repo="neuphonic/neucodec",
43
- codec_device="cpu",
44
- ):
45
-
46
- # Constants
47
- self.sample_rate = 24_000
48
- self.max_context = 2048
49
- self.hop_length = 480
50
- self.streaming_overlap_frames = 1
51
- self.streaming_frames_per_chunk = 25
52
- self.streaming_lookforward = 5
53
- self.streaming_lookback = 50
54
- self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
55
-
56
- # ggml & onnx flags
57
- self._is_quantized_model = False
58
- self._is_onnx_codec = False
59
-
60
- # HF tokenizer
61
- self.tokenizer = None
62
-
63
- # Load models
64
- self._load_backbone(backbone_repo, backbone_device)
65
- self._load_codec(codec_repo, codec_device)
66
-
67
- def _load_backbone(self, backbone_repo, backbone_device):
68
- print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
69
-
70
- if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
71
- try:
72
- from llama_cpp import Llama
73
- except ImportError as e:
74
- raise ImportError(
75
- "Failed to import `llama_cpp`. "
76
- "Please install it with:\n"
77
- " pip install llama-cpp-python"
78
- ) from e
79
- self.backbone = Llama.from_pretrained(
80
- repo_id=backbone_repo,
81
- filename="*.gguf",
82
- verbose=False,
83
- n_gpu_layers=-1 if backbone_device == "gpu" else 0,
84
- n_ctx=self.max_context,
85
- mlock=True,
86
- flash_attn=True if backbone_device == "gpu" else False,
87
- )
88
- self._is_quantized_model = True
89
-
90
- else:
91
- self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
92
- self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
93
- torch.device(backbone_device)
94
- )
95
-
96
- def _load_codec(self, codec_repo, codec_device):
97
- print(f"Loading codec from: {codec_repo} on {codec_device} ...")
98
- match codec_repo:
99
- case "neuphonic/neucodec":
100
- self.codec = NeuCodec.from_pretrained(codec_repo)
101
- self.codec.eval().to(codec_device)
102
- case "neuphonic/distill-neucodec":
103
- self.codec = DistillNeuCodec.from_pretrained(codec_repo)
104
- self.codec.eval().to(codec_device)
105
- case "neuphonic/neucodec-onnx-decoder":
106
- if codec_device != "cpu":
107
- raise ValueError("Onnx decoder only currently runs on CPU.")
108
- try:
109
- from neucodec import NeuCodecOnnxDecoder
110
- except ImportError as e:
111
- raise ImportError(
112
- "Failed to import the onnx decoder."
113
- " Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
114
- ) from e
115
- self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
116
- self._is_onnx_codec = True
117
- case _:
118
- raise ValueError(f"Unsupported codec repository: {codec_repo}")
119
-
120
- def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
121
- """
122
- Perform inference to generate speech from text using the TTS model and reference audio.
123
-
124
- Args:
125
- text (str): Input text to be converted to speech.
126
- ref_codes (np.ndarray | torch.tensor): Encoded reference.
127
- ref_text (str): Reference text for reference audio. Defaults to None.
128
- Returns:
129
- np.ndarray: Generated speech waveform.
130
- """
131
-
132
- # Generate tokens
133
- if self._is_quantized_model:
134
- output_str = self._infer_ggml(ref_codes, ref_text, text)
135
- else:
136
- prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
137
- output_str = self._infer_torch(prompt_ids)
138
-
139
- # Decode
140
- wav = self._decode(output_str)
141
-
142
- return wav
143
-
144
- def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
145
- """
146
- Perform streaming inference to generate speech from text using the TTS model and reference audio.
147
-
148
- Args:
149
- text (str): Input text to be converted to speech.
150
- ref_codes (np.ndarray | torch.tensor): Encoded reference.
151
- ref_text (str): Reference text for reference audio. Defaults to None.
152
- Yields:
153
- np.ndarray: Generated speech waveform.
154
- """
155
-
156
- if self._is_quantized_model:
157
- return self._infer_stream_ggml(ref_codes, ref_text, text)
158
- else:
159
- raise NotImplementedError("Streaming is not implemented for the torch backend!")
160
-
161
- def encode_reference(self, ref_audio_path: str | Path):
162
- wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
163
- wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0) # [1, 1, T]
164
- with torch.no_grad():
165
- ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
166
- return ref_codes
167
-
168
- def _decode(self, codes: str):
169
- """Decode speech tokens to audio waveform."""
170
- # Extract speech token IDs using regex
171
- speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
172
-
173
- if len(speech_ids) == 0:
174
- raise ValueError(
175
- "No valid speech tokens found in the output. "
176
- "The model may not have generated proper speech tokens."
177
- )
178
-
179
- # Onnx decode
180
- if self._is_onnx_codec:
181
- codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
182
- recon = self.codec.decode_code(codes)
183
- # Torch decode
184
- else:
185
- with torch.no_grad():
186
- codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
187
- self.codec.device
188
- )
189
- recon = self.codec.decode_code(codes).cpu().numpy()
190
-
191
- return recon[0, 0, :]
192
-
193
- def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
194
- input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
195
-
196
- speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
197
- speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
198
- text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
199
- text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
200
- text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
201
-
202
- input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
203
- chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
204
- ids = self.tokenizer.encode(chat)
205
-
206
- text_replace_idx = ids.index(text_replace)
207
- ids = (
208
- ids[:text_replace_idx]
209
- + [text_prompt_start]
210
- + input_ids
211
- + [text_prompt_end]
212
- + ids[text_replace_idx + 1 :] # noqa
213
- )
214
-
215
- speech_replace_idx = ids.index(speech_replace)
216
- codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
217
- codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
218
- ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
219
-
220
- return ids
221
-
222
- def _infer_torch(self, prompt_ids: list[int]) -> str:
223
- prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
224
- speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
225
- with torch.no_grad():
226
- output_tokens = self.backbone.generate(
227
- prompt_tensor,
228
- max_length=self.max_context,
229
- eos_token_id=speech_end_id,
230
- do_sample=True,
231
- temperature=1,
232
- top_k=50,
233
- use_cache=True,
234
- min_new_tokens=50,
235
- )
236
- input_length = prompt_tensor.shape[-1]
237
- output_str = self.tokenizer.decode(
238
- output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
239
- )
240
- return output_str
241
-
242
- def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
243
- ref_text = phonemize_with_dict(ref_text)
244
- input_text = phonemize_with_dict(input_text)
245
-
246
- codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
247
- prompt = (
248
- f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
249
- f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
250
- )
251
- output = self.backbone(
252
- prompt,
253
- max_tokens=self.max_context,
254
- temperature=1.0,
255
- top_k=50,
256
- stop=["<|SPEECH_GENERATION_END|>"],
257
- )
258
- output_str = output["choices"][0]["text"]
259
- return output_str
260
-
261
- def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
262
- ref_text = phonemize_with_dict(ref_text)
263
- input_text = phonemize_with_dict(input_text)
264
-
265
- codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
266
- prompt = (
267
- f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
268
- f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
269
- )
270
-
271
- audio_cache: list[np.ndarray] = []
272
- token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
273
- n_decoded_samples: int = 0
274
- n_decoded_tokens: int = len(ref_codes)
275
-
276
- for item in self.backbone(
277
- prompt,
278
- max_tokens=self.max_context,
279
- temperature=0.2,
280
- top_k=50,
281
- stop=["<|SPEECH_GENERATION_END|>"],
282
- stream=True
283
- ):
284
- output_str = item["choices"][0]["text"]
285
- token_cache.append(output_str)
286
-
287
- if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
288
-
289
- # decode chunk
290
- tokens_start = max(
291
- n_decoded_tokens
292
- - self.streaming_lookback
293
- - self.streaming_overlap_frames,
294
- 0
295
- )
296
- tokens_end = (
297
- n_decoded_tokens
298
- + self.streaming_frames_per_chunk
299
- + self.streaming_lookforward
300
- + self.streaming_overlap_frames
301
- )
302
- sample_start = (
303
- n_decoded_tokens - tokens_start
304
- ) * self.hop_length
305
- sample_end = (
306
- sample_start
307
- + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
308
- )
309
- curr_codes = token_cache[tokens_start:tokens_end]
310
- recon = self._decode("".join(curr_codes))
311
- recon = recon[sample_start:sample_end]
312
- audio_cache.append(recon)
313
-
314
- # postprocess
315
- processed_recon = _linear_overlap_add(
316
- audio_cache, stride=self.streaming_stride_samples
317
- )
318
- new_samples_end = len(audio_cache) * self.streaming_stride_samples
319
- processed_recon = processed_recon[
320
- n_decoded_samples:new_samples_end
321
- ]
322
- n_decoded_samples = new_samples_end
323
- n_decoded_tokens += self.streaming_frames_per_chunk
324
- yield processed_recon
325
-
326
- # final decoding handled separately as non-constant chunk size
327
- remaining_tokens = len(token_cache) - n_decoded_tokens
328
- if len(token_cache) > n_decoded_tokens:
329
- tokens_start = max(
330
- len(token_cache)
331
- - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
332
- 0
333
- )
334
- sample_start = (
335
- len(token_cache)
336
- - tokens_start
337
- - remaining_tokens
338
- - self.streaming_overlap_frames
339
- ) * self.hop_length
340
- curr_codes = token_cache[tokens_start:]
341
- recon = self._decode("".join(curr_codes))
342
- recon = recon[sample_start:]
343
- audio_cache.append(recon)
344
-
345
- processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
346
- processed_recon = processed_recon[n_decoded_samples:]
347
- yield processed_recon