Upload 2 files
- vieneu_tts/__init__.py +4 -0
- vieneu_tts/vieneu_tts.py +385 -0
vieneu_tts/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .vieneu_tts import VieNeuTTS
+
+__all__ = ["VieNeuTTS"]
+
vieneu_tts/vieneu_tts.py
ADDED
@@ -0,0 +1,385 @@
+from pathlib import Path
+from typing import Generator
+import librosa
+import numpy as np
+import torch
+from neucodec import NeuCodec, DistillNeuCodec
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from utils.phonemize_text import phonemize_text, phonemize_with_dict
+import re
+
+def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
+    # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
+    assert len(frames)
+    dtype = frames[0].dtype
+    shape = frames[0].shape[:-1]
+
+    total_size = 0
+    for i, frame in enumerate(frames):
+        frame_end = stride * i + frame.shape[-1]
+        total_size = max(total_size, frame_end)
+
+    sum_weight = np.zeros(total_size, dtype=dtype)
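+    # out holds the weighted sum of all frames; sum_weight renormalizes the overlapping regions at the end.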
+    out = np.zeros((*shape, total_size), dtype=dtype)
+
+    offset: int = 0
+    for frame in frames:
+        frame_length = frame.shape[-1]
+        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+        weight = 0.5 - np.abs(t - 0.5)  # triangular crossfade weight, peaking at the frame centre
+
+        out[..., offset : offset + frame_length] += weight * frame
+        sum_weight[offset : offset + frame_length] += weight
+        offset += stride
+    assert sum_weight.min() > 0
+    return out / sum_weight
+
+class VieNeuTTS:
+    def __init__(
+        self,
+        backbone_repo="pnnbao-ump/VieNeu-TTS",
+        backbone_device="cpu",
+        codec_repo="neuphonic/neucodec",
+        codec_device="cpu",
+    ):
+
+        # Constants
+        self.sample_rate = 24_000
+        self.max_context = 2048
+        self.hop_length = 480
+        self.streaming_overlap_frames = 1
+        self.streaming_frames_per_chunk = 25
+        self.streaming_lookforward = 5
+        self.streaming_lookback = 50
+        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
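+        # Streaming geometry: 25 frames/chunk * 480 samples/frame = 0.5 s of audio per chunk at 24 kHz.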
+
+        # ggml & onnx flags
+        self._is_quantized_model = False
+        self._is_onnx_codec = False
+
+        # HF tokenizer
+        self.tokenizer = None
+
+        # Load models
+        self._load_backbone(backbone_repo, backbone_device)
+        self._load_codec(codec_repo, codec_device)
+
+    def _load_backbone(self, backbone_repo, backbone_device):
+        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
+
+        if "gguf" in backbone_repo.lower():
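+            # GGUF checkpoint: serve the backbone with llama.cpp instead of transformers.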
+            try:
+                from llama_cpp import Llama
+            except ImportError as e:
+                raise ImportError(
+                    "Failed to import `llama_cpp`. "
+                    "Please install it with:\n"
+                    "  pip install llama-cpp-python"
+                ) from e
+            self.backbone = Llama.from_pretrained(
+                repo_id=backbone_repo,
+                filename="*.gguf",
+                verbose=False,
+                n_gpu_layers=-1 if backbone_device == "gpu" else 0,
+                n_ctx=self.max_context,
+                mlock=True,
+                flash_attn=(backbone_device == "gpu"),
+            )
+            self._is_quantized_model = True
+
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
+            print(f"  Loading model to device: {backbone_device}")
+
+            print("  📦 Loading with FP32 (stable mode)")
+            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo)
+
+            print(f"  Model loaded, moving to {backbone_device}...")
+            self.backbone = self.backbone.to(torch.device(backbone_device))
+            print(f"  ✓ Backbone on device: {next(self.backbone.parameters()).device}")
+            print(f"  ✓ Backbone dtype: {next(self.backbone.parameters()).dtype}")
+
+    def _load_codec(self, codec_repo, codec_device):
+        print(f"Loading codec from: {codec_repo} on {codec_device} ...")
+        match codec_repo:
+            case "neuphonic/neucodec":
+                self.codec = NeuCodec.from_pretrained(codec_repo)
+
+                # Keep codec in FP32 for compatibility with feature_extractor
+                # Only backbone uses FP16
+                print("  📦 Keeping codec in FP32 (compatibility)")
+
+                self.codec.eval().to(codec_device)
+                print(f"  ✓ Codec on device: {next(self.codec.parameters()).device}")
+                print(f"  ✓ Codec dtype: {next(self.codec.parameters()).dtype}")
+            case "neuphonic/distill-neucodec":
+                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
+
+                # Keep distill-codec in FP32 for compatibility
+                print("  📦 Keeping distill-codec in FP32 (compatibility)")
+
+                self.codec.eval().to(codec_device)
+                print(f"  ✓ Distill-Codec on device: {next(self.codec.parameters()).device}")
+                print(f"  ✓ Distill-Codec dtype: {next(self.codec.parameters()).dtype}")
+            case "neuphonic/neucodec-onnx-decoder":
+                if codec_device != "cpu":
+                    raise ValueError("Onnx decoder only currently runs on CPU.")
+                try:
+                    from neucodec import NeuCodecOnnxDecoder
+                except ImportError as e:
+                    raise ImportError(
+                        "Failed to import the onnx decoder."
+                        " Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
+                    ) from e
+                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
+                self._is_onnx_codec = True
+            case _:
+                raise ValueError(f"Unsupported codec repository: {codec_repo}")
+
+    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
+        """
+        Perform inference to generate speech from text using the TTS model and reference audio.
+
+        Args:
+            text (str): Input text to be converted to speech.
+            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
+            ref_text (str): Transcript of the reference audio.
+        Returns:
+            np.ndarray: Generated speech waveform.
+        """
+
+        # Generate tokens
+        if self._is_quantized_model:
+            output_str = self._infer_ggml(ref_codes, ref_text, text)
+        else:
+            prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
+            output_str = self._infer_torch(prompt_ids)
+
+        # Decode
+        wav = self._decode(output_str)
+
+        return wav
+
+    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+        """
+        Perform streaming inference to generate speech from text using the TTS model and reference audio.
+
+        Args:
+            text (str): Input text to be converted to speech.
+            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
+            ref_text (str): Transcript of the reference audio.
+        Yields:
+            np.ndarray: Generated speech waveform chunks.
+        """
+
+        if self._is_quantized_model:
+            return self._infer_stream_ggml(ref_codes, ref_text, text)
+        else:
+            raise NotImplementedError("Streaming is not implemented for the torch backend!")
+
+    def encode_reference(self, ref_audio_path: str | Path):
+        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
+        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
+
+        # NeuCodec expects CPU tensor for encode_code
+        wav_tensor_cpu = wav_tensor.cpu().float()
+
+        with torch.no_grad():
+            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor_cpu).squeeze(0).squeeze(0)
+
+        # Ensure result is on CPU for caching
+        if ref_codes.device.type != 'cpu':
+            ref_codes = ref_codes.cpu()
+
+        return ref_codes
+
+    def _decode(self, codes: str):
+        """Decode speech tokens to audio waveform."""
+        # Extract speech token IDs using regex
+        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
+
+        if len(speech_ids) == 0:
+            raise ValueError(
+                "No valid speech tokens found in the output. "
+                "The model may not have generated proper speech tokens."
+            )
+
+        # Onnx decode
+        if self._is_onnx_codec:
+            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
+            recon = self.codec.decode_code(codes)
+        # Torch decode
+        else:
+            with torch.no_grad():
+                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
+                    self.codec.device
+                )
+
+                # Codec is kept in FP32, no need for autocast
+                recon = self.codec.decode_code(codes).cpu().numpy()
+
+        return recon[0, 0, :]
+
+    def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
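+        # Splice the phonemized text between the TEXT_PROMPT markers, then append the
+        # reference speech codes after SPEECH_GENERATION_START so generation continues them.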
+        # Convert ref_codes to list if it's a tensor
+        if hasattr(ref_codes, 'cpu'):
+            ref_codes = ref_codes.cpu().numpy().tolist()
+        elif hasattr(ref_codes, 'tolist'):
+            ref_codes = ref_codes.tolist()
+
+        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
+
+        speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
+        speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
+        text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
+        text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
+        text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
+
+        input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
+        chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
+        ids = self.tokenizer.encode(chat)
+
+        text_replace_idx = ids.index(text_replace)
+        ids = (
+            ids[:text_replace_idx]
+            + [text_prompt_start]
+            + input_ids
+            + [text_prompt_end]
+            + ids[text_replace_idx + 1 :]  # noqa
+        )
+
+        speech_replace_idx = ids.index(speech_replace)
+        codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
+        codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
+        ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
+
+        return ids
+
+    def _infer_torch(self, prompt_ids: list[int]) -> str:
+        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
+        speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
+
+        with torch.no_grad():
+            output_tokens = self.backbone.generate(
+                prompt_tensor,
+                max_length=self.max_context,
+                eos_token_id=speech_end_id,
+                do_sample=True,
+                temperature=1.0,
+                top_k=50,
+                use_cache=True,
+                min_new_tokens=50,
+            )
+
+        input_length = prompt_tensor.shape[-1]
+        output_str = self.tokenizer.decode(
+            output_tokens[0, input_length:].cpu().numpy().tolist(), skip_special_tokens=False
+        )
+        return output_str
+
+    def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
+        ref_text = phonemize_with_dict(ref_text)
+        input_text = phonemize_with_dict(input_text)
+
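+        # Same prompt layout as the torch path, but assembled as a plain string for llama.cpp.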
+        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+        prompt = (
+            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+        )
+        output = self.backbone(
+            prompt,
+            max_tokens=self.max_context,
+            temperature=1.0,
+            top_k=50,
+            stop=["<|SPEECH_GENERATION_END|>"],
+        )
+        output_str = output["choices"][0]["text"]
+        return output_str
+
+    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
+        ref_text = phonemize_with_dict(ref_text)
+        input_text = phonemize_with_dict(input_text)
+
+        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+        prompt = (
+            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+        )
+
+        audio_cache: list[np.ndarray] = []
+        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
+        n_decoded_samples: int = 0
+        n_decoded_tokens: int = len(ref_codes)
+
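+        # Each chunk is re-decoded with lookback/lookforward context; only the chunk's own
+        # frames (plus one overlap frame) are kept and crossfaded via _linear_overlap_add.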
+        for item in self.backbone(
+            prompt,
+            max_tokens=self.max_context,
+            temperature=0.2,
+            top_k=50,
+            stop=["<|SPEECH_GENERATION_END|>"],
+            stream=True
+        ):
+            output_str = item["choices"][0]["text"]
+            token_cache.append(output_str)
+
+            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
+
+                # decode chunk
+                tokens_start = max(
+                    n_decoded_tokens
+                    - self.streaming_lookback
+                    - self.streaming_overlap_frames,
+                    0
+                )
+                tokens_end = (
+                    n_decoded_tokens
+                    + self.streaming_frames_per_chunk
+                    + self.streaming_lookforward
+                    + self.streaming_overlap_frames
+                )
+                sample_start = (
+                    n_decoded_tokens - tokens_start
+                ) * self.hop_length
+                sample_end = (
+                    sample_start
+                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
+                )
+                curr_codes = token_cache[tokens_start:tokens_end]
+                recon = self._decode("".join(curr_codes))
+                recon = recon[sample_start:sample_end]
+                audio_cache.append(recon)
+
+                # postprocess
+                processed_recon = _linear_overlap_add(
+                    audio_cache, stride=self.streaming_stride_samples
+                )
+                new_samples_end = len(audio_cache) * self.streaming_stride_samples
+                processed_recon = processed_recon[
+                    n_decoded_samples:new_samples_end
+                ]
+                n_decoded_samples = new_samples_end
+                n_decoded_tokens += self.streaming_frames_per_chunk
+                yield processed_recon
+
+        # final decoding handled separately as non-constant chunk size
+        remaining_tokens = len(token_cache) - n_decoded_tokens
+        if len(token_cache) > n_decoded_tokens:
+            tokens_start = max(
+                len(token_cache)
+                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
+                0
+            )
+            sample_start = (
+                len(token_cache)
+                - tokens_start
+                - remaining_tokens
+                - self.streaming_overlap_frames
+            ) * self.hop_length
+            curr_codes = token_cache[tokens_start:]
+            recon = self._decode("".join(curr_codes))
+            recon = recon[sample_start:]
+            audio_cache.append(recon)
+
+            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
+            processed_recon = processed_recon[n_decoded_samples:]
+            yield processed_recon
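
For reference, a minimal usage sketch of the class added above. It assumes a local reference clip "ref.wav" and its transcript (both placeholders, not part of the commit) and uses the third-party soundfile package to write the output:

import soundfile as sf

from vieneu_tts import VieNeuTTS

# Default settings: torch backbone + NeuCodec, both on CPU.
tts = VieNeuTTS()

# Encode the voice to clone; resampling to 16 kHz mono is handled internally.
ref_codes = tts.encode_reference("ref.wav")
ref_text = "transcript of ref.wav"  # must match the reference audio

# Generate a 24 kHz waveform conditioned on the reference voice (Vietnamese sample text).
wav = tts.infer("Xin chào, rất vui được gặp bạn.", ref_codes, ref_text)
sf.write("output.wav", wav, tts.sample_rate)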