import re
from pathlib import Path
from typing import Generator

import librosa
import numpy as np
import torch
from neucodec import NeuCodec, DistillNeuCodec
from transformers import AutoTokenizer, AutoModelForCausalLM

from utils.phonemize_text import phonemize_text, phonemize_with_dict


def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    """Stitch overlapping audio frames into one waveform via weighted overlap-add.

    Each frame is weighted by a triangular window peaking at its centre, so
    overlapping regions cross-fade smoothly instead of clicking at the seams.
    """
    assert len(frames)
    dtype = frames[0].dtype
    shape = frames[0].shape[:-1]

    # The output length is set by whichever frame ends last.
    total_size = 0
    for i, frame in enumerate(frames):
        frame_end = stride * i + frame.shape[-1]
        total_size = max(total_size, frame_end)

    sum_weight = np.zeros(total_size, dtype=dtype)
    out = np.zeros(shape + (total_size,), dtype=dtype)

    offset: int = 0
    for frame in frames:
        frame_length = frame.shape[-1]
        # Triangular window over the open interval (0, 1); the endpoints are
        # dropped so every sample keeps a strictly positive weight.
        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = 0.5 - np.abs(t - 0.5)

        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
        offset += stride
    assert sum_weight.min() > 0
    return out / sum_weight
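
# A quick sanity check (hypothetical values, not part of the pipeline): two
# constant frames of 10 samples at stride 5 overlap by 5 samples, and the
# triangular cross-fade reconstructs a constant 15-sample signal:
#
#   frames = [np.ones(10, dtype=np.float32)] * 2
#   out = _linear_overlap_add(frames, stride=5)
#   assert out.shape == (15,) and np.allclose(out, 1.0)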


class VieNeuTTS:
    def __init__(
        self,
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device="cpu",
        codec_repo="neuphonic/neucodec",
        codec_device="cpu",
    ):
        # Audio and context settings.
        self.sample_rate = 24_000
        self.max_context = 2048
        self.hop_length = 480  # samples per codec frame -> 50 frames/s at 24 kHz

        # Streaming geometry: decode 25 new frames (0.5 s of audio) per chunk,
        # with lookback/lookforward context frames around each chunk and a
        # small overlap for the cross-fade in `_linear_overlap_add`.
        self.streaming_overlap_frames = 1
        self.streaming_frames_per_chunk = 25
        self.streaming_lookforward = 5
        self.streaming_lookback = 50
        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length

        # Backend flags, set while loading.
        self._is_quantized_model = False
        self._is_onnx_codec = False

        # Only set for the torch backbone; the GGUF backbone tokenizes internally.
        self.tokenizer = None

        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)

    def _load_backbone(self, backbone_repo, backbone_device):
        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")

        if "gguf" in backbone_repo.lower():
            try:
                from llama_cpp import Llama
            except ImportError as e:
                raise ImportError(
                    "Failed to import `llama_cpp`. "
                    "Please install it with:\n"
                    "  pip install llama-cpp-python"
                ) from e
            self.backbone = Llama.from_pretrained(
                repo_id=backbone_repo,
                filename="*.gguf",
                verbose=False,
                n_gpu_layers=-1 if backbone_device == "gpu" else 0,
                n_ctx=self.max_context,
                mlock=True,
                flash_attn=backbone_device == "gpu",
            )
            self._is_quantized_model = True
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
                torch.device(backbone_device)
            )
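
    # Note: for the GGUF backend, `backbone_device` is interpreted as
    # "cpu"/"gpu" (GPU offload via n_gpu_layers); for the torch backend it can
    # be any torch device string such as "cpu", "cuda", or "cuda:0".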

    def _load_codec(self, codec_repo, codec_device):
        print(f"Loading codec from: {codec_repo} on {codec_device} ...")
        match codec_repo:
            case "neuphonic/neucodec":
                self.codec = NeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/distill-neucodec":
                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/neucodec-onnx-decoder":
                if codec_device != "cpu":
                    raise ValueError("The ONNX decoder currently only runs on CPU.")
                try:
                    from neucodec import NeuCodecOnnxDecoder
                except ImportError as e:
                    raise ImportError(
                        "Failed to import the ONNX decoder. "
                        "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
                    ) from e
                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
                self._is_onnx_codec = True
            case _:
                raise ValueError(f"Unsupported codec repository: {codec_repo}")

    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
        """
        Perform inference to generate speech from text using the TTS model and reference audio.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
            ref_text (str): Transcript of the reference audio.

        Returns:
            np.ndarray: Generated speech waveform.
        """
        if self._is_quantized_model:
            output_str = self._infer_ggml(ref_codes, ref_text, text)
        else:
            prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
            output_str = self._infer_torch(prompt_ids)

        wav = self._decode(output_str)
        return wav

    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
        """
        Perform streaming inference to generate speech from text using the TTS model and reference audio.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
            ref_text (str): Transcript of the reference audio.

        Yields:
            np.ndarray: Generated speech waveform chunks.
        """
        if self._is_quantized_model:
            return self._infer_stream_ggml(ref_codes, ref_text, text)
        else:
            raise NotImplementedError("Streaming is not implemented for the torch backend!")

    def encode_reference(self, ref_audio_path: str | Path):
        # The codec encoder expects 16 kHz mono input.
        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # (1, 1, T)
        with torch.no_grad():
            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
        return ref_codes
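
    # Note: encoding consumes 16 kHz audio while decoding reconstructs 24 kHz
    # audio; with hop_length = 480 that works out to 50 codec tokens per
    # second of speech, which bounds how much reference audio fits in the
    # 2048-token context.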

    def _decode(self, codes: str):
        """Decode a string of speech tokens to an audio waveform."""
        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]

        if len(speech_ids) == 0:
            raise ValueError(
                "No valid speech tokens found in the output. "
                "The model may not have generated proper speech tokens."
            )

        if self._is_onnx_codec:
            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
            recon = self.codec.decode_code(codes)
        else:
            with torch.no_grad():
                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
                    self.codec.device
                )
                recon = self.codec.decode_code(codes).cpu().numpy()

        # (batch, channels, samples) -> mono waveform
        return recon[0, 0, :]
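
    # For example, re.findall(r"<\|speech_(\d+)\|>", "<|speech_12|><|speech_7|>")
    # returns ["12", "7"], so any non-token text the model emits is simply
    # ignored by the decode step.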

    def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
        # The reference transcript is prepended so the model continues the
        # reference speech codes with speech for the new text.
        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)

        speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
        speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
        text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
        text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
        text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")

        input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
        chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
        ids = self.tokenizer.encode(chat)

        # Splice the phonemized text in place of <|TEXT_REPLACE|>.
        text_replace_idx = ids.index(text_replace)
        ids = (
            ids[:text_replace_idx]
            + [text_prompt_start]
            + input_ids
            + [text_prompt_end]
            + ids[text_replace_idx + 1 :]
        )

        # Splice the reference speech codes in place of <|SPEECH_REPLACE|>.
        speech_replace_idx = ids.index(speech_replace)
        codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
        codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
        ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)

        return ids
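
    # The assembled prompt therefore looks like (schematically):
    #
    #   user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref phonemes} {text phonemes}<|TEXT_PROMPT_END|>
    #   assistant:<|SPEECH_GENERATION_START|><|speech_123|><|speech_456|>...
    #
    # which mirrors the string prompt built in `_infer_ggml` below.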

    def _infer_torch(self, prompt_ids: list[int]) -> str:
        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
        speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
        with torch.no_grad():
            output_tokens = self.backbone.generate(
                prompt_tensor,
                max_length=self.max_context,
                eos_token_id=speech_end_id,
                do_sample=True,
                temperature=1.0,
                top_k=50,
                use_cache=True,
                min_new_tokens=50,
            )
        input_length = prompt_tensor.shape[-1]
        # Keep special tokens so the <|speech_N|> codes survive for `_decode`.
        output_str = self.tokenizer.decode(
            output_tokens[0, input_length:].cpu().numpy().tolist(), skip_special_tokens=False
        )
        return output_str

    def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
        ref_text = phonemize_with_dict(ref_text)
        input_text = phonemize_with_dict(input_text)

        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )
        output = self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=1.0,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
        )
        output_str = output["choices"][0]["text"]
        return output_str
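
    # The GGUF backbone returns only newly generated text (the reference codes
    # in the prompt are not echoed back), so `output_str` can be passed
    # straight to `_decode`.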

    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
        ref_text = phonemize_with_dict(ref_text)
        input_text = phonemize_with_dict(input_text)

        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )

        audio_cache: list[np.ndarray] = []
        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
        n_decoded_samples: int = 0
        n_decoded_tokens: int = len(ref_codes)

        for item in self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=0.2,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
            stream=True,
        ):
            output_str = item["choices"][0]["text"]
            token_cache.append(output_str)

            # Decode once a full chunk plus its lookforward context is available.
            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
                # Decode the chunk together with surrounding context tokens,
                # then cut only the chunk (plus overlap) out of the audio.
                tokens_start = max(
                    n_decoded_tokens
                    - self.streaming_lookback
                    - self.streaming_overlap_frames,
                    0,
                )
                tokens_end = (
                    n_decoded_tokens
                    + self.streaming_frames_per_chunk
                    + self.streaming_lookforward
                    + self.streaming_overlap_frames
                )
                sample_start = (
                    n_decoded_tokens - tokens_start
                ) * self.hop_length
                sample_end = (
                    sample_start
                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
                )
                curr_codes = token_cache[tokens_start:tokens_end]
                recon = self._decode("".join(curr_codes))
                recon = recon[sample_start:sample_end]
                audio_cache.append(recon)

                # Cross-fade the new chunk into the audio decoded so far and
                # emit only the not-yet-yielded samples.
                processed_recon = _linear_overlap_add(
                    audio_cache, stride=self.streaming_stride_samples
                )
                new_samples_end = len(audio_cache) * self.streaming_stride_samples
                processed_recon = processed_recon[
                    n_decoded_samples:new_samples_end
                ]
                n_decoded_samples = new_samples_end
                n_decoded_tokens += self.streaming_frames_per_chunk
                yield processed_recon

        # Flush whatever remains after the model stops generating.
        remaining_tokens = len(token_cache) - n_decoded_tokens
        if remaining_tokens > 0:
            tokens_start = max(
                len(token_cache)
                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
                0,
            )
            sample_start = (
                len(token_cache)
                - tokens_start
                - remaining_tokens
                - self.streaming_overlap_frames
            ) * self.hop_length
            curr_codes = token_cache[tokens_start:]
            recon = self._decode("".join(curr_codes))
            recon = recon[sample_start:]
            audio_cache.append(recon)

            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
            processed_recon = processed_recon[n_decoded_samples:]
            yield processed_recon
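

# A minimal usage sketch (file paths and texts below are hypothetical; adjust
# them to your own checkpoints and reference audio):
#
#   tts = VieNeuTTS(backbone_device="cpu", codec_device="cpu")
#   ref_codes = tts.encode_reference("samples/reference.wav")
#   wav = tts.infer("Xin chào các bạn!", ref_codes, "transcript of reference.wav")
#
#   import soundfile as sf
#   sf.write("output.wav", wav, tts.sample_rate)
#
#   # Streaming requires a GGUF backbone (see `infer_stream`):
#   for chunk in tts.infer_stream("Xin chào!", ref_codes, "transcript"):
#       ...  # play or buffer each chunk as it arrives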