from pathlib import Path
from functools import lru_cache
import os
import re

import onnxruntime
from loguru import logger
from kokoro_onnx import Kokoro
from misaki import espeak, ja, en, zh
from misaki.espeak import EspeakG2P

from lib.utils import Timer, write_audio

providers = onnxruntime.get_available_providers()
print(f"Available onnx runtime providers: {providers}")

MODEL_DIR = Path(r"D:\yujuan\yoyo-translator-win\models\kokoro")


def create_session(model_path):
    # See list of providers: https://github.com/microsoft/onnxruntime/issues/22101#issuecomment-2357667377
    providers = onnxruntime.get_available_providers()
    print(f"Available onnx runtime providers: {providers}")

    # See session options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html#thread-management
    sess_options = onnxruntime.SessionOptions()
    cpu_count = os.cpu_count() // 2
    print(f"Setting threads to CPU cores count: {cpu_count}")
    sess_options.intra_op_num_threads = cpu_count

    session = onnxruntime.InferenceSession(
        model_path, providers=providers, sess_options=sess_options
    )
    return session


class KokoroTTS:
    # Default voice id for each supported language code.
    language_voice_mapping = {
        "JP": "jf_alpha",
        "JA": "jf_alpha",
        "ZH": "zf_xiaoyi",
        "EN": "af_heart",
        "FR": "ff_siwis",
        "IT": "im_nicola",
        "HI": "hf_alpha",
        "PT": "im_nicola",
        "ES": "im_nicola",
    }

    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, gcp=None, voice=None):
        self._session = create_session(model_path)
        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
        self.g2p = gcp
        self.voice = voice

    @classmethod
    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
        model_path: str = str(model_dir / "kokoro-quant.onnx")
        voice_model_path: str = str(model_dir / "voices-v1.0.bin")
        voice = cls.language_voice_mapping.get(language.upper())
        logger.info(f"[TTS] language: {language}")
        if not voice:
            raise ValueError(f"Unsupported language: {language}, voice: {voice}")

        # Build the language-specific G2P frontend; some branches also run a short
        # warm-up utterance right after construction.
        if language.upper() == "ZH":
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "zh_config.json",
                      gcp=zh.ZHG2P(), voice=voice)
            tts.generate("你好")
        elif language.upper() in ["JP", "JA"]:
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "ja_config.json",
                      gcp=ja.JAG2P(), voice=voice)
        elif language.upper() == "EN":
            fallback = espeak.EspeakFallback(british=False)
            tts = cls(model_path, voice_model_path,
                      gcp=en.G2P(trf=False, british=False, fallback=fallback), voice=voice)
            tts.generate("hello")
        elif language.upper() == "HI":
            g2p = EspeakG2P(language="hi")
            tts = cls(model_path, voice_model_path, gcp=g2p, voice=voice)
            tts.generate("हेलो")
        elif language.upper() == "IT":
            g2p = EspeakG2P(language="it")
            tts = cls(model_path, voice_model_path, gcp=g2p, voice=voice)
            tts.generate("Ciao")
        elif language.upper() == "PT":
            g2p = EspeakG2P(language="pt-br")
            tts = cls(model_path, voice_model_path, gcp=g2p, voice=voice)
            tts.generate("Olá")
        elif language.upper() == "ES":
            g2p = EspeakG2P(language="es")
            tts = cls(model_path, voice_model_path, gcp=g2p, voice=voice)
            tts.generate("Hola")
        elif language.upper() == "FR":
            g2p = EspeakG2P(language="fr-fr")
            tts = cls(model_path, voice_model_path, gcp=g2p, voice=voice)
            tts.generate("Bonjour")
        else:
            tts = cls(model_path, voice_model_path, gcp=EspeakG2P(language=language.lower()), voice=voice)
        return tts

    def generate(self, text, speed=1.2):
        with Timer("tts inference") as t:
            phonemes, _ = self.g2p(text)
            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)
        return samples, sample_rate, t.duration

    async def stream(self, text, speed=1.2):
        phonemes, _ = self.g2p(text)
        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
        async for samples, sample_rate in stream:
            yield samples, sample_rate


@lru_cache
def get_model(language):
    # resource_path is assumed to be a project-level helper (defined elsewhere, e.g. in
    # lib.utils) that resolves bundled resource directories. The keyword is model_dir,
    # matching from_language's signature.
    return KokoroTTS.from_language(language=language, model_dir=resource_path("models/kokoro"))
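

# Hedged usage sketch (not part of the original module): shows how KokoroTTS might be
# driven for both one-shot and streaming synthesis. Assumptions: MODEL_DIR (or the
# directory passed in) contains kokoro-quant.onnx and voices-v1.0.bin, and Timer from
# lib.utils exposes a .duration attribute as the generate() return value implies.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # from_language is used directly here so the demo does not depend on resource_path.
        tts = KokoroTTS.from_language("EN")

        # One-shot synthesis: returns the full waveform plus the measured inference time.
        samples, sample_rate, duration = tts.generate("Hello from Kokoro", speed=1.2)
        print(f"Generated {len(samples)} samples at {sample_rate} Hz in {duration} s")

        # Streaming synthesis: audio chunks are yielded as they are decoded.
        async for chunk, chunk_rate in tts.stream("Streaming hello from Kokoro"):
            print(f"Received chunk of {len(chunk)} samples at {chunk_rate} Hz")

    asyncio.run(_demo())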