|
|
from pathlib import Path |
|
|
from kokoro_onnx import Kokoro |
|
|
from misaki import espeak, ja, en, zh |
|
|
from misaki.espeak import EspeakG2P |
|
|
import re |
|
|
from functools import lru_cache |
|
|
from loguru import logger |
|
|
import onnxruntime |
|
|
import os |
|
|
from lib.utils import Timer, write_audio |
|
|
|
|
|
# Log which ONNX Runtime execution providers are available on this machine
# (e.g. CUDA, DirectML, CPU) as an import-time diagnostic.
providers = onnxruntime.get_available_providers()


print(f"Available onnx runtime providers: {providers}")


# Default on-disk location of the Kokoro model assets (ONNX model, voice
# bin, per-language vocab configs). NOTE(review): hard-coded absolute
# Windows path — works only on the packaging machine; callers can override
# via KokoroTTS.from_language(model_dir=...).
MODEL_DIR = Path(r"D:\yujuan\yoyo-translator-win\models\kokoro")
|
|
|
|
|
def create_session(model_path):
    """Create an ONNX Runtime inference session for a Kokoro model.

    Args:
        model_path: Filesystem path to the ``.onnx`` model file.

    Returns:
        An ``onnxruntime.InferenceSession`` configured with every available
        execution provider and roughly half of the CPU cores for intra-op
        parallelism.
    """
    providers = onnxruntime.get_available_providers()
    print(f"Available onnx runtime providers: {providers}")

    sess_options = onnxruntime.SessionOptions()
    # Use half the cores to leave headroom for the rest of the app.
    # os.cpu_count() may return None on some platforms (the old
    # `os.cpu_count() // 2` would raise TypeError there and would set
    # 0 threads on a single-core machine), so guard both cases.
    cpu_count = max(1, (os.cpu_count() or 2) // 2)
    print(f"Setting threads to CPU cores count: {cpu_count}")
    sess_options.intra_op_num_threads = cpu_count
    session = onnxruntime.InferenceSession(
        model_path, providers=providers, sess_options=sess_options
    )
    return session
|
|
|
|
|
|
|
|
class KokoroTTS:
    """Text-to-speech engine wrapping the Kokoro ONNX model.

    Pairs a Kokoro acoustic model with a language-specific grapheme-to-
    phoneme (G2P) frontend from ``misaki``. Build instances via
    :meth:`from_language`, which also selects a default voice.
    """

    # Default Kokoro voice id for each supported language code.
    language_voice_mapping = {
        "JP": "jf_alpha",
        "JA": "jf_alpha",
        "ZH": "zf_xiaoyi",
        "EN": "af_heart",
        "FR": "ff_siwis",
        "IT": "im_nicola",
        "HI": "hf_alpha",
        "PT": "im_nicola",
        "ES": "im_nicola"
    }

    # Languages handled by the generic espeak G2P frontend:
    # language code -> (espeak voice id, warm-up text).
    _ESPEAK_LANGS = {
        "HI": ("hi", "हेलो"),
        "IT": ("it", "Ciao"),
        "PT": ("pt-br", "Olá"),
        "ES": ("es", "Hola"),
        "FR": ("fr-fr", "Bonjour"),
    }

    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, gcp=None, voice=None):
        """Load the ONNX model and voice bank.

        Args:
            model_path: Path to the Kokoro ``.onnx`` model file.
            voice_model_path: Path to the packed voices ``.bin`` file.
            vocab_config: Optional per-language vocab config passed through
                to ``Kokoro.from_session`` (used for ZH/JA).
            gcp: Grapheme-to-phoneme callable; ``gcp(text)`` must return a
                ``(phonemes, _)`` pair.
            voice: Voice id used for synthesis (see language_voice_mapping).
        """
        self._session = create_session(model_path)
        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
        self.g2p = gcp
        self.voice = voice

    @classmethod
    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
        """Build a ready-to-use TTS engine for *language*.

        Picks the matching voice and G2P frontend, then fires a short
        warm-up synthesis for most languages so later calls are fast.

        Args:
            language: Case-insensitive language code (e.g. "en", "ZH").
            model_dir: Directory containing the Kokoro model assets.

        Raises:
            ValueError: If *language* has no voice mapping.
        """
        model_path: str = str(model_dir / "kokoro-quant.onnx")
        voice_model_path: str = str(model_dir / "voices-v1.0.bin")
        lang = language.upper()
        voice = cls.language_voice_mapping.get(lang)
        logger.info(f"[TTS] language: {language}")
        if not voice:
            raise ValueError(f"Unsupported language: {language}, voice: {voice}")
        if lang == "ZH":
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "zh_config.json", gcp=zh.ZHG2P(),
                      voice=voice)
            tts.generate("你好")
        elif lang in ('JP', 'JA'):
            # NOTE(review): Japanese intentionally skips the warm-up call
            # in the original code — preserved as-is.
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "ja_config.json", gcp=ja.JAG2P(),
                      voice=voice)
        elif lang == 'EN':
            # English uses the misaki G2P with an espeak fallback for
            # out-of-vocabulary words.
            fallback = espeak.EspeakFallback(british=False)
            tts = cls(model_path, voice_model_path, gcp=en.G2P(trf=False, british=False, fallback=fallback),
                      voice=voice)
            tts.generate("hello")
        elif lang in cls._ESPEAK_LANGS:
            espeak_lang, warmup_text = cls._ESPEAK_LANGS[lang]
            tts = cls(model_path, voice_model_path, gcp=EspeakG2P(language=espeak_lang), voice=voice)
            tts.generate(warmup_text)
        else:
            # Defensive fallback for any future mapping entry without a
            # dedicated branch: try espeak with the raw lowercase code.
            tts = cls(model_path, voice_model_path, gcp=EspeakG2P(language.lower()), voice=voice)
        return tts

    def generate(self, text, speed=1.2):
        """Synthesize *text* to audio.

        Args:
            text: Input text in the engine's language.
            speed: Playback speed multiplier.

        Returns:
            ``(samples, sample_rate, duration)`` where *duration* is the
            wall-clock inference time measured by ``Timer``.
        """
        with Timer("tts inference") as t:
            phonemes, _ = self.g2p(text)
            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)

        return samples, sample_rate, t.duration

    async def stream(self, text, speed=1.2):
        """Asynchronously yield ``(samples, sample_rate)`` chunks for *text*."""
        phonemes, _ = self.g2p(text)
        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
        async for samples, sample_rate in stream:
            yield samples, sample_rate
|
|
|
|
|
|
|
|
@lru_cache
def get_model(language):
    """Return a memoized KokoroTTS engine for *language*.

    Model loading is expensive, so one instance is cached per language
    code for the lifetime of the process.
    """
    # BUG FIX: the previous call used `model_dir_path=` (the parameter is
    # named `model_dir`) and an undefined `resource_path` helper, so every
    # call raised. Use the module-level MODEL_DIR default instead.
    return KokoroTTS.from_language(language=language, model_dir=MODEL_DIR)
|
|
|