import os
import torch
import torchaudio
import whisper
import onnxruntime
import numpy as np
import torchaudio.compliance.kaldi as kaldi
from typing import Callable, List, Union
from functools import partial
from loguru import logger
from VietTTS.utils.frontend_utils import split_text, normalize_text, mel_spectrogram
from VietTTS.tokenizer.tokenizer import get_tokenizer


class TTSFrontEnd:
    def __init__(
        self,
        speech_embedding_model: str,
        speech_tokenizer_model: str,
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = get_tokenizer()
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # The speaker-embedding model always runs on CPU.
        self.speech_embedding_session = onnxruntime.InferenceSession(
            speech_embedding_model,
            sess_options=option,
            providers=["CPUExecutionProvider"]
        )
        # The speech tokenizer runs on CUDA when a GPU is available.
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model,
            sess_options=option,
            providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]
        )
        self.spk2info = {}

    def _extract_text_token(self, text: str):
        text_token = self.tokenizer.encode(text, allowed_special='all')
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech: torch.Tensor):
        # The Whisper mel frontend handles at most 30 s of 16 kHz audio; truncate longer prompts.
        if speech.shape[1] / 16000 > 30:
            speech = speech[:, :int(16000 * 30)]
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(
            None,
            {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
             self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)}
        )[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech: torch.Tensor):
        feat = kaldi.fbank(
            waveform=speech,
            num_mel_bins=80,
            dither=0,
            sample_frequency=16000
        )
        # Per-utterance mean normalization before the embedding model.
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.speech_embedding_session.run(
            None,
            {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
        )[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech: torch.Tensor):
        speech_feat = mel_spectrogram(
            y=speech,
            n_fft=1024,
            num_mels=80,
            sampling_rate=22050,
            hop_size=256,
            win_size=1024,
            fmin=0,
            fmax=8000,
            center=False
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def preprocess_text(self, text: str, split: bool = True) -> Union[str, List[str]]:
        text = normalize_text(text)
        if split:
            text = list(split_text(
                text=text,
                tokenize=partial(self.tokenizer.encode, allowed_special='all'),
                token_max_n=30,
                token_min_n=10,
                merge_len=5,
                comma_split=False
            ))
        return text
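    # Illustrative only (not in the original source): with split=True the method
    # returns sentence-sized chunks whose token counts are bounded by the
    # token_min_n/token_max_n values above, e.g.
    #     frontend.preprocess_text('Câu một. Câu hai. Câu ba.')
    #     -> ['Câu một. Câu hai.', 'Câu ba.']
    # The exact grouping depends on split_text and the tokenizer.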

    def frontend_tts(
        self,
        text: str,
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        text_token, text_token_len = self._extract_text_token(text)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        # Mel features are computed at 22.05 kHz, so resample the 16 kHz prompt first.
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {
            'text': text_token,
            'text_len': text_token_len,
            'flow_prompt_speech_token': speech_token,
            'flow_prompt_speech_token_len': speech_token_len,
            'prompt_speech_feat': speech_feat,
            'prompt_speech_feat_len': speech_feat_len,
            'llm_embedding': embedding,
            'flow_embedding': embedding
        }
        return model_input

    def frontend_vc(
        self,
        source_speech_16k: Union[np.ndarray, torch.Tensor],
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        if isinstance(source_speech_16k, np.ndarray):
            source_speech_16k = torch.from_numpy(source_speech_16k)
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {
            'source_speech_token': source_speech_token,
            'source_speech_token_len': source_speech_token_len,
            'flow_prompt_speech_token': prompt_speech_token,
            'flow_prompt_speech_token_len': prompt_speech_token_len,
            'prompt_speech_feat': prompt_speech_feat,
            'prompt_speech_feat_len': prompt_speech_feat_len,
            'flow_embedding': embedding
        }
        return model_input
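

# Usage sketch (added for illustration; not part of the original module). The
# ONNX paths and the prompt WAV below are placeholders -- substitute the model
# files shipped with VietTTS. frontend_tts() only builds the input dict; the
# actual synthesis happens in the downstream LLM/flow/vocoder stack.
if __name__ == '__main__':
    frontend = TTSFrontEnd(
        speech_embedding_model='models/speech_embedding.onnx',   # placeholder path
        speech_tokenizer_model='models/speech_tokenizer.onnx',   # placeholder path
    )
    prompt_speech_16k, sr = torchaudio.load('prompt.wav')        # placeholder prompt; 16 kHz mono
    assert sr == 16000, 'reference prompt must be sampled at 16 kHz'
    for chunk in frontend.preprocess_text('Xin chào, đây là một ví dụ.', split=True):
        model_input = frontend.frontend_tts(chunk, prompt_speech_16k)
        logger.info(f'text chunk {chunk!r} -> keys: {list(model_input.keys())}')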