# support test_models on Intel — yujuanqin, commit b295d06
from pathlib import Path
from kokoro_onnx import Kokoro
from misaki import espeak, ja, en, zh
from misaki.espeak import EspeakG2P
import re
from functools import lru_cache
from loguru import logger
import onnxruntime
import os
from lib.utils import Timer, write_audio
# Log the execution providers available on this machine once at import time
# (useful when diagnosing CPU-vs-GPU behavior across deployments).
providers = onnxruntime.get_available_providers()
print(f"Available onnx runtime providers: {providers}")

# Directory containing the Kokoro model files. Overridable via the
# KOKORO_MODEL_DIR environment variable so the module is not tied to one
# developer's absolute Windows checkout path (the original default is kept
# for backward compatibility).
MODEL_DIR = Path(
    os.environ.get("KOKORO_MODEL_DIR", r"D:\yujuan\yoyo-translator-win\models\kokoro")
)
def create_session(model_path):
    """Create an ONNX Runtime inference session for the given model file.

    Uses every execution provider available on this machine and caps the
    intra-op thread count at roughly half the logical CPU cores to leave
    headroom for the rest of the pipeline.

    Args:
        model_path: Path to the ``.onnx`` model file.

    Returns:
        A configured ``onnxruntime.InferenceSession``.
    """
    # See list of providers https://github.com/microsoft/onnxruntime/issues/22101#issuecomment-2357667377
    providers = onnxruntime.get_available_providers()
    logger.info(f"Available onnx runtime providers: {providers}")
    # See session options https://onnxruntime.ai/docs/performance/tune-performance/threading.html#thread-management
    sess_options = onnxruntime.SessionOptions()
    # os.cpu_count() may return None, and halving a single core would give 0;
    # clamp so we always request at least one intra-op thread.
    cpu_count = max(1, (os.cpu_count() or 2) // 2)
    logger.info(f"Setting threads to CPU cores count: {cpu_count}")
    sess_options.intra_op_num_threads = cpu_count
    session = onnxruntime.InferenceSession(
        model_path, providers=providers, sess_options=sess_options
    )
    return session
class KokoroTTS:
    """Kokoro ONNX text-to-speech engine with per-language G2P front ends.

    An instance bundles an ONNX Runtime session, the Kokoro model, a
    grapheme-to-phoneme (G2P) converter and a voice id. Use the
    ``from_language`` factory to obtain a fully configured instance.
    """

    # Upper-cased language code -> Kokoro voice id.
    language_voice_mapping = {
        "JP": "jf_alpha",
        "JA": "jf_alpha",
        "ZH": "zf_xiaoyi",
        "EN": "af_heart",
        "FR": "ff_siwis",
        "IT": "im_nicola",
        "HI": "hf_alpha",
        "PT": "im_nicola",
        "ES": "im_nicola",
    }

    # Languages served by the generic espeak G2P:
    # language code -> (espeak voice id, warm-up text synthesized once at startup).
    _espeak_g2p_config = {
        "HI": ("hi", "हेलो"),
        "IT": ("it", "Ciao"),
        "PT": ("pt-br", "Olá"),
        "ES": ("es", "Hola"),
        "FR": ("fr-fr", "Bonjour"),
    }

    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, gcp=None, voice=None):
        """Initialize the TTS engine.

        Args:
            model_path: Path to the Kokoro ``.onnx`` model.
            voice_model_path: Path to the ``voices`` ``.bin`` file.
            vocab_config: Optional path to a language-specific vocab config
                (needed for ZH and JA).
            gcp: Grapheme-to-phoneme converter callable. (The parameter name
                is kept as ``gcp`` for backward compatibility, though it is a
                G2P object.)
            voice: Kokoro voice id to synthesize with.
        """
        self._session = create_session(model_path)
        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
        self.g2p = gcp
        self.voice = voice

    @classmethod
    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
        """Build a :class:`KokoroTTS` configured for *language*.

        Most languages run one short warm-up synthesis so the first real
        request does not pay the model warm-up cost.

        Args:
            language: Language code (case-insensitive), e.g. ``"en"``, ``"zh"``.
            model_dir: Directory holding the Kokoro model files.

        Returns:
            A ready-to-use :class:`KokoroTTS` instance.

        Raises:
            ValueError: If no voice is mapped for *language*.
        """
        model_path: str = str(model_dir / "kokoro-quant.onnx")
        voice_model_path: str = str(model_dir / "voices-v1.0.bin")
        lang = language.upper()
        voice = cls.language_voice_mapping.get(lang)
        logger.info(f"[TTS] language: {language}")
        if not voice:
            raise ValueError(f"Unsupported language: {language}, voice: {voice}")
        if lang == "ZH":
            # Chinese needs its own vocab config and G2P.
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "zh_config.json",
                      gcp=zh.ZHG2P(), voice=voice)
            tts.generate("你好")
        elif lang in ("JP", "JA"):
            # NOTE(review): Japanese intentionally performs no warm-up call,
            # matching the original behavior.
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "ja_config.json",
                      gcp=ja.JAG2P(), voice=voice)
        elif lang == "EN":
            fallback = espeak.EspeakFallback(british=False)
            tts = cls(model_path, voice_model_path,
                      gcp=en.G2P(trf=False, british=False, fallback=fallback), voice=voice)
            tts.generate("hello")
        elif lang in cls._espeak_g2p_config:
            # All remaining mapped languages share the espeak-based G2P;
            # only the espeak voice id and warm-up phrase differ.
            espeak_lang, warmup_text = cls._espeak_g2p_config[lang]
            tts = cls(model_path, voice_model_path, gcp=EspeakG2P(language=espeak_lang), voice=voice)
            tts.generate(warmup_text)
        else:
            # Unreachable while language_voice_mapping only lists the codes
            # handled above; kept as a fallback for future voice additions.
            tts = cls(model_path, voice_model_path, gcp=EspeakG2P(language.lower()), voice=voice)
        return tts

    def generate(self, text, speed=1.2):
        """Synthesize *text* to audio.

        Args:
            text: Input text in the engine's configured language.
            speed: Playback speed multiplier passed to the model.

        Returns:
            Tuple ``(samples, sample_rate, inference_seconds)``.
        """
        with Timer("tts inference") as t:
            phonemes, _ = self.g2p(text)
            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)
        return samples, sample_rate, t.duration

    async def stream(self, text, speed=1.2):
        """Asynchronously yield ``(samples, sample_rate)`` chunks for *text*."""
        phonemes, _ = self.g2p(text)
        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
        async for samples, sample_rate in stream:
            yield samples, sample_rate
@lru_cache
def get_model(language):
    """Return a cached :class:`KokoroTTS` for *language* (one per language).

    ``lru_cache`` is safe to leave unbounded here: the key space is the small
    fixed set of supported language codes.
    """
    # BUG FIX: the factory parameter is ``model_dir`` (not ``model_dir_path``),
    # and ``resource_path`` is not defined or imported in this module — the
    # original call raised at runtime. Use the module-level MODEL_DIR instead.
    return KokoroTTS.from_language(language=language, model_dir=MODEL_DIR)