# CLONE / vieneu_tts.py
# Uploaded by pnnbao-ump; commit 1c22460 (verified): "Update vieneu_tts.py"
from pathlib import Path
from typing import Generator
import librosa
import numpy as np
import torch
from neucodec import NeuCodec, DistillNeuCodec
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.phonemize_text import phonemize_text, phonemize_with_dict
import re
def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    """Overlap-add a list of frames using a triangular cross-fade window.

    Frame ``i`` is placed at sample offset ``stride * i`` along the last axis.
    Every frame is multiplied by a triangular weight (peak in the middle,
    tapering towards both ends), the weighted frames are summed, and the sum
    is divided by the accumulated weight, so wherever frames overlap the
    result is a weighted average of them.

    Args:
        frames: Non-empty list of arrays that share their leading dimensions;
            the last axis is time. Frames may have different lengths.
            NOTE(review): a floating-point dtype is assumed, since the dtype
            of the first frame is reused for the weights.
        stride: Offset, in samples, between the starts of consecutive frames.

    Returns:
        Array of shape ``(*leading_dims, total_size)`` where ``total_size`` is
        the largest ``stride * i + len(frame_i)``.

    Raises:
        AssertionError: If ``frames`` is empty, or if some output position is
            not covered by any frame (accumulated weight of zero).
    """
    # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
    assert len(frames)
    dtype = frames[0].dtype
    shape = frames[0].shape[:-1]

    total_size = 0
    for i, frame in enumerate(frames):
        total_size = max(total_size, stride * i + frame.shape[-1])

    sum_weight = np.zeros(total_size, dtype=dtype)
    # The shape must be passed as one tuple: np.zeros' second positional
    # argument is the dtype, so splatting the leading dims breaks for any
    # frame with more than one dimension.
    out = np.zeros((*shape, total_size), dtype=dtype)

    offset = 0
    for frame in frames:
        frame_length = frame.shape[-1]
        # Drop the two end points so the weight is strictly positive.
        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        # Triangle with its maximum at the middle of the frame, as in the
        # referenced implementation: ``0.5 - |t - 0.5|``.
        weight = 0.5 - np.abs(t - 0.5)

        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
        offset += stride

    assert sum_weight.min() > 0
    return out / sum_weight
class VieNeuTTS:
def __init__(
self,
backbone_repo="pnnbao-ump/VieNeu-TTS",
backbone_device="cpu",
codec_repo="neuphonic/neucodec",
codec_device="cpu",
):
# Constants
self.sample_rate = 24_000
self.max_context = 2048
self.hop_length = 480
self.streaming_overlap_frames = 1
self.streaming_frames_per_chunk = 25
self.streaming_lookforward = 5
self.streaming_lookback = 50
self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
# ggml & onnx flags
self._is_quantized_model = False
self._is_onnx_codec = False
# HF tokenizer
self.tokenizer = None
# Load models
self._load_backbone(backbone_repo, backbone_device)
self._load_codec(codec_repo, codec_device)
def _load_backbone(self, backbone_repo, backbone_device):
print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
try:
from llama_cpp import Llama
except ImportError as e:
raise ImportError(
"Failed to import `llama_cpp`. "
"Please install it with:\n"
" pip install llama-cpp-python"
) from e
self.backbone = Llama.from_pretrained(
repo_id=backbone_repo,
filename="*.gguf",
verbose=False,
n_gpu_layers=-1 if backbone_device == "gpu" else 0,
n_ctx=self.max_context,
mlock=True,
flash_attn=True if backbone_device == "gpu" else False,
)
self._is_quantized_model = True
else:
self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
torch.device(backbone_device)
)
def _load_codec(self, codec_repo, codec_device):
print(f"Loading codec from: {codec_repo} on {codec_device} ...")
match codec_repo:
case "neuphonic/neucodec":
self.codec = NeuCodec.from_pretrained(codec_repo)
self.codec.eval().to(codec_device)
case "neuphonic/distill-neucodec":
self.codec = DistillNeuCodec.from_pretrained(codec_repo)
self.codec.eval().to(codec_device)
case "neuphonic/neucodec-onnx-decoder":
if codec_device != "cpu":
raise ValueError("Onnx decoder only currently runs on CPU.")
try:
from neucodec import NeuCodecOnnxDecoder
except ImportError as e:
raise ImportError(
"Failed to import the onnx decoder."
" Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
) from e
self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
self._is_onnx_codec = True
case _:
raise ValueError(f"Unsupported codec repository: {codec_repo}")
def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
"""
Perform inference to generate speech from text using the TTS model and reference audio.
Args:
text (str): Input text to be converted to speech.
ref_codes (np.ndarray | torch.tensor): Encoded reference.
ref_text (str): Reference text for reference audio. Defaults to None.
Returns:
np.ndarray: Generated speech waveform.
"""
# Generate tokens
if self._is_quantized_model:
output_str = self._infer_ggml(ref_codes, ref_text, text)
else:
prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
output_str = self._infer_torch(prompt_ids)
# Decode
wav = self._decode(output_str)
return wav
def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
"""
Perform streaming inference to generate speech from text using the TTS model and reference audio.
Args:
text (str): Input text to be converted to speech.
ref_codes (np.ndarray | torch.tensor): Encoded reference.
ref_text (str): Reference text for reference audio. Defaults to None.
Yields:
np.ndarray: Generated speech waveform.
"""
if self._is_quantized_model:
return self._infer_stream_ggml(ref_codes, ref_text, text)
else:
raise NotImplementedError("Streaming is not implemented for the torch backend!")
def encode_reference(self, ref_audio_path: str | Path):
wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0) # [1, 1, T]
with torch.no_grad():
ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
return ref_codes
def _decode(self, codes: str):
"""Decode speech tokens to audio waveform."""
# Extract speech token IDs using regex
speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
if len(speech_ids) == 0:
raise ValueError(
"No valid speech tokens found in the output. "
"The model may not have generated proper speech tokens."
)
# Onnx decode
if self._is_onnx_codec:
codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
recon = self.codec.decode_code(codes)
# Torch decode
else:
with torch.no_grad():
codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
self.codec.device
)
recon = self.codec.decode_code(codes).cpu().numpy()
return recon[0, 0, :]
def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
ids = self.tokenizer.encode(chat)
text_replace_idx = ids.index(text_replace)
ids = (
ids[:text_replace_idx]
+ [text_prompt_start]
+ input_ids
+ [text_prompt_end]
+ ids[text_replace_idx + 1 :] # noqa
)
speech_replace_idx = ids.index(speech_replace)
codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
return ids
def _infer_torch(self, prompt_ids: list[int]) -> str:
prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
with torch.no_grad():
output_tokens = self.backbone.generate(
prompt_tensor,
max_length=self.max_context,
eos_token_id=speech_end_id,
do_sample=True,
temperature=1,
top_k=50,
use_cache=True,
min_new_tokens=50,
)
input_length = prompt_tensor.shape[-1]
output_str = self.tokenizer.decode(
output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
)
return output_str
def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
ref_text = phonemize_with_dict(ref_text)
input_text = phonemize_with_dict(input_text)
codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
prompt = (
f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
)
output = self.backbone(
prompt,
max_tokens=self.max_context,
temperature=1.0,
top_k=50,
stop=["<|SPEECH_GENERATION_END|>"],
)
output_str = output["choices"][0]["text"]
return output_str
def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
ref_text = phonemize_with_dict(ref_text)
input_text = phonemize_with_dict(input_text)
codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
prompt = (
f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
)
audio_cache: list[np.ndarray] = []
token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
n_decoded_samples: int = 0
n_decoded_tokens: int = len(ref_codes)
for item in self.backbone(
prompt,
max_tokens=self.max_context,
temperature=0.2,
top_k=50,
stop=["<|SPEECH_GENERATION_END|>"],
stream=True
):
output_str = item["choices"][0]["text"]
token_cache.append(output_str)
if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
# decode chunk
tokens_start = max(
n_decoded_tokens
- self.streaming_lookback
- self.streaming_overlap_frames,
0
)
tokens_end = (
n_decoded_tokens
+ self.streaming_frames_per_chunk
+ self.streaming_lookforward
+ self.streaming_overlap_frames
)
sample_start = (
n_decoded_tokens - tokens_start
) * self.hop_length
sample_end = (
sample_start
+ (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
)
curr_codes = token_cache[tokens_start:tokens_end]
recon = self._decode("".join(curr_codes))
recon = recon[sample_start:sample_end]
audio_cache.append(recon)
# postprocess
processed_recon = _linear_overlap_add(
audio_cache, stride=self.streaming_stride_samples
)
new_samples_end = len(audio_cache) * self.streaming_stride_samples
processed_recon = processed_recon[
n_decoded_samples:new_samples_end
]
n_decoded_samples = new_samples_end
n_decoded_tokens += self.streaming_frames_per_chunk
yield processed_recon
# final decoding handled separately as non-constant chunk size
remaining_tokens = len(token_cache) - n_decoded_tokens
if len(token_cache) > n_decoded_tokens:
tokens_start = max(
len(token_cache)
- (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
0
)
sample_start = (
len(token_cache)
- tokens_start
- remaining_tokens
- self.streaming_overlap_frames
) * self.hop_length
curr_codes = token_cache[tokens_start:]
recon = self._decode("".join(curr_codes))
recon = recon[sample_start:]
audio_cache.append(recon)
processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
processed_recon = processed_recon[n_decoded_samples:]
yield processed_recon