| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| from loguru import logger | |
| from hyperpyyaml import load_hyperpyyaml | |
| from VietTTS.model import TTSModel | |
| from VietTTS.frontend import TTSFrontEnd | |
| from VietTTS.utils.file_utils import download_model, save_wav | |
class TTS:
    """High-level VietTTS interface.

    Loads the frontend (ONNX speech embedding / tokenizer) and the acoustic
    model stack (LLM + flow + HiFT) from a model directory, downloading the
    assets from HuggingFace if the directory is missing, and exposes
    text-to-speech and voice-conversion inference.
    """

    def __init__(self, model_dir, load_jit=False, load_onnx=False):
        """Load all model assets from *model_dir*.

        Args:
            model_dir: Directory holding ``config.yaml``, the frontend ONNX
                models and the ``llm.pt`` / ``flow.pt`` / ``hift.pt``
                checkpoints. Downloaded from HuggingFace when absent.
            load_jit: Additionally load the TorchScript submodules
                (fp16 text encoder / LLM, fp32 flow encoder).
            load_onnx: Additionally load the ONNX flow-decoder estimator.
        """
        if not os.path.exists(model_dir):
            logger.info("Downloading model from huggingface [dangvansam/viet-tts]")
            download_model(model_dir)
        with open(f"{model_dir}/config.yaml", "r") as f:
            configs = load_hyperpyyaml(f)
        self.frontend = TTSFrontEnd(
            speech_embedding_model=f"{model_dir}/speech_embedding.onnx",
            speech_tokenizer_model=f"{model_dir}/speech_tokenizer.onnx",
        )
        self.model = TTSModel(llm=configs["llm"], flow=configs["flow"], hift=configs["hift"])
        self.model.load(
            llm_model=f"{model_dir}/llm.pt",
            flow_model=f"{model_dir}/flow.pt",
            hift_model=f"{model_dir}/hift.pt",
        )
        if load_jit:
            self.model.load_jit(
                f"{model_dir}/llm.text_encoder.fp16.zip",
                f"{model_dir}/llm.llm.fp16.zip",
                f"{model_dir}/flow.encoder.fp32.zip",
            )
            logger.success("Loaded jit model from {}".format(model_dir))
        if load_onnx:
            self.model.load_onnx(f"{model_dir}/flow.decoder.estimator.fp32.onnx")
            logger.success("Loaded onnx model from {}".format(model_dir))
        logger.success("Loaded model from {}".format(model_dir))
        self.model_dir = model_dir

    def list_available_spks(self):
        """Return the names of the built-in speakers known to the frontend."""
        return list(self.frontend.spk2info.keys())

    # Backward-compatible alias for the original (misspelled) public name;
    # existing callers of ``list_avaliable_spks`` keep working.
    def list_avaliable_spks(self):
        """Deprecated alias of :meth:`list_available_spks`."""
        return self.list_available_spks()

    def inference_tts(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
        """Yield model outputs for each normalized segment of *tts_text*.

        Args:
            tts_text: Raw input text; the frontend splits it into segments.
            prompt_speech_16k: 16 kHz reference speech for voice cloning.
            stream: Forwarded to the model for streaming synthesis.
            speed: Playback-speed factor forwarded to the model.

        Yields:
            Model output dicts (containing at least ``"tts_speech"``).
        """
        for segment in tqdm(self.frontend.preprocess_text(tts_text, split=True)):
            model_input = self.frontend.frontend_tts(segment, prompt_speech_16k)
            yield from self.model.tts(**model_input, stream=stream, speed=speed)

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        """Yield voice-conversion outputs for *source_speech_16k* re-voiced
        with the timbre of *prompt_speech_16k*."""
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
        yield from self.model.vc(**model_input, stream=stream, speed=speed)

    def tts_to_wav(self, text, prompt_speech_16k, speed=1.0):
        """Synthesize *text* and return the waveform as a 1-D numpy array.

        Returns an empty float32 array when synthesis produces no segments
        (previously this crashed in ``np.concatenate``).
        """
        wavs = [
            output["tts_speech"].squeeze(0).numpy()
            for output in self.inference_tts(text, prompt_speech_16k, stream=False, speed=speed)
        ]
        if not wavs:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(wavs, axis=0)

    def tts_to_file(self, text, prompt_speech_16k, speed, output_path, sample_rate=22050):
        """Synthesize *text* and write it to *output_path* as a WAV file.

        Args:
            text: Input text to synthesize.
            prompt_speech_16k: 16 kHz reference speech for voice cloning.
            speed: Playback-speed factor.
            output_path: Destination WAV path.
            sample_rate: Output sample rate; defaults to the model's native
                22050 Hz (previously hard-coded).
        """
        wav = self.tts_to_wav(text, prompt_speech_16k, speed)
        save_wav(wav, sample_rate, output_path)