# Commit b295d06 (yujuanqin): support test_models on Intel
import os
from pathlib import Path
from kokoro_onnx import Kokoro
from misaki import espeak, en, zh
from misaki.espeak import EspeakG2P
from logging import getLogger
import onnxruntime
from lib.utils import Timer, write_audio
logger = getLogger(__name__)

# Providers available in this onnxruntime build. Informational only:
# create_session() queries this again and currently forces CPUExecutionProvider.
providers = onnxruntime.get_available_providers()

# Local directory holding the Kokoro model, voices and vocab config.
# Fixed an accidental doubled leading slash ("//Users/...") — pathlib keeps a
# leading "//", whose meaning is implementation-defined on POSIX.
# NOTE(review): hard-coded developer path — consider reading this from an
# environment variable or config instead.
MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models/kokoro")
def create_session(model_path):
# See list of providers https://github.com/microsoft/onnxruntime/issues/22101#issuecomment-2357667377
providers = onnxruntime.get_available_providers()
print(f"Available onnx runtime providers: {providers}")
# See session options https://onnxruntime.ai/docs/performance/tune-performance/threading.html#thread-management
sess_options = onnxruntime.SessionOptions()
cpu_count = os.cpu_count() // 2
print(f"Setting threads to CPU cores count: {cpu_count}")
sess_options.intra_op_num_threads = cpu_count
session = onnxruntime.InferenceSession(
model_path, providers=["CPUExecutionProvider"], sess_options=sess_options
)
return session
class KokoroTTS:
    """Text-to-speech wrapper around a Kokoro ONNX model with per-language G2P.

    Instances are normally built via :meth:`from_language`, which selects the
    voice, the grapheme-to-phoneme frontend and a warm-up phrase for the
    requested language code.
    """

    # Language code -> Kokoro voice id used for synthesis.
    language_voice_mapping = {
        "JP": "jf_alpha",
        "JA": "jf_alpha",
        "ZH": "zf_xiaoyi",
        "EN": "af_heart",
        "FR": "ff_siwis",
        "IT": "im_nicola",
        "HI": "hf_alpha",
        "PT": "im_nicola",
        "ES": "im_nicola",
    }
    # Language code -> short phrase used to warm up the model right after load.
    # JP/JA entries added: they have voices above, but previously had no
    # warm-up phrase, so from_language("JP") passed None into generate().
    language_word_mapping = {
        "JP": "こんにちは",
        "JA": "こんにちは",
        "ZH": "你好",
        "EN": "hello",
        "FR": "Bonjour",
        "IT": "Ciao",
        "HI": "हेलो",
        "PT": "Olá",
        "ES": "Hola",
    }

    # Languages whose G2P is a plain espeak frontend, keyed by espeak voice id.
    _espeak_language_codes = {
        "HI": "hi",
        "IT": "it",
        "PT": "pt-br",
        "ES": "es",
        "FR": "fr-fr",
    }

    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, gcp=None, voice=None):
        """Load the ONNX session and voice bank.

        :param model_path: path to the Kokoro .onnx model.
        :param voice_model_path: path to the voices .bin file.
        :param vocab_config: optional vocab config path (used for ZH).
        :param gcp: grapheme-to-phoneme callable. NOTE: the name is likely a
            typo for ``g2p``; it is kept for backward compatibility with
            existing keyword callers.
        :param voice: Kokoro voice id to synthesize with.
        """
        self._session = create_session(model_path)
        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
        self.g2p = gcp
        self.voice = voice

    @classmethod
    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
        """Build a :class:`KokoroTTS` for ``language`` and warm it up.

        :param language: case-insensitive language code (e.g. "ZH", "en").
        :param model_dir: directory containing the model/voice/vocab files.
        :raises ValueError: if no voice is mapped for ``language``.
        """
        lang = language.upper()
        model_path: str = str(model_dir / "kokoro-quant.onnx")
        voice_model_path: str = str(model_dir / "voices-v1.0.bin")
        voice = cls.language_voice_mapping.get(lang)
        warm_up_text = cls.language_word_mapping.get(lang)
        logger.info(f"[TTS] language: {language}")
        if not voice:
            raise ValueError(f"Unsupported language: {language}, voice: {voice}")
        vocab_config = None
        if lang == "ZH":
            g2p = zh.ZHG2P()
            # Chinese needs its own vocab config alongside the model.
            vocab_config = model_dir / "zh_config.json"
        elif lang == "EN":
            # American English with an espeak fallback for out-of-lexicon words.
            fallback = espeak.EspeakFallback(british=False)
            g2p = en.G2P(trf=False, british=False, fallback=fallback)
        elif lang in cls._espeak_language_codes:
            g2p = EspeakG2P(language=cls._espeak_language_codes[lang])
        else:
            # Remaining mapped languages (JP/JA) fall back to espeak keyed on
            # the lowercased code, matching the original behavior.
            g2p = EspeakG2P(language.lower())
        with Timer("load tts"):
            tts = cls(model_path, voice_model_path, vocab_config=vocab_config, gcp=g2p, voice=voice)
            # Guard: only warm up when a phrase exists for this language.
            if warm_up_text:
                tts.generate(warm_up_text)
        return tts

    def generate(self, text, speed=1.2):
        """Synthesize ``text`` and return (samples, sample_rate, seconds_taken)."""
        with Timer("tts inference") as t:
            phonemes, _ = self.g2p(text)
            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)
        return samples, sample_rate, t.duration

    async def stream(self, text, speed=1.2):
        """Asynchronously yield (samples, sample_rate) chunks for ``text``."""
        phonemes, _ = self.g2p(text)
        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
        async for samples, sample_rate in stream:
            yield samples, sample_rate
if __name__ == '__main__':
    # Smoke test: synthesize one Chinese sentence, save it, report timing.
    chinese_tts = KokoroTTS.from_language(language="ZH")
    audio, sample_rate, elapsed = chinese_tts.generate("今天天气怎么样?")
    write_audio("tts_out.wav", audio, sample_rate)
    print(elapsed)