File size: 4,644 Bytes
1e495f3
 
b295d06
1e495f3
b295d06
1e495f3
b295d06
1e495f3
b295d06
1e495f3
 
 
b295d06
 
1e495f3
 
 
 
 
 
 
 
 
 
 
 
b295d06
1e495f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b295d06
 
1e495f3
 
 
 
 
b295d06
 
 
 
 
 
1e495f3
 
b295d06
 
 
1e495f3
 
b295d06
 
1e495f3
 
b295d06
 
1e495f3
 
b295d06
 
1e495f3
 
b295d06
 
1e495f3
 
b295d06
 
1e495f3
b295d06
1e495f3
 
 
 
 
 
b295d06
1e495f3
 
 
 
 
 
 
 
 
b295d06
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from pathlib import Path
from kokoro_onnx import Kokoro
from misaki import espeak, ja, en, zh
from misaki.espeak import EspeakG2P
import re
from functools import lru_cache
from loguru import logger
import onnxruntime
import os
from lib.utils import Timer, write_audio

# Probe at import time which execution providers (CUDA, DirectML, CPU, ...)
# this onnxruntime build exposes, mainly for startup diagnostics.
providers = onnxruntime.get_available_providers()
print(f"Available onnx runtime providers: {providers}")
# Default on-disk location of the bundled Kokoro model files
# (hard-coded Windows deployment path; overridable via from_language(model_dir=...)).
MODEL_DIR = Path(r"D:\yujuan\yoyo-translator-win\models\kokoro")

def create_session(model_path):
    """Create an ONNX Runtime inference session for the given model file.

    Uses every provider the local onnxruntime build reports as available,
    and caps the intra-op thread pool at roughly half the CPU core count
    to leave headroom for the rest of the application.

    Args:
        model_path: Filesystem path to the ``.onnx`` model.

    Returns:
        A ready-to-use ``onnxruntime.InferenceSession``.
    """
    # See list of providers https://github.com/microsoft/onnxruntime/issues/22101#issuecomment-2357667377
    providers = onnxruntime.get_available_providers()
    print(f"Available onnx runtime providers: {providers}")

    # See session options https://onnxruntime.ai/docs/performance/tune-performance/threading.html#thread-management
    sess_options = onnxruntime.SessionOptions()
    # os.cpu_count() may return None, and a single-core host would give
    # 1 // 2 == 0; clamp so we always request at least one thread.
    cpu_count = max(1, (os.cpu_count() or 1) // 2)
    print(f"Setting threads to CPU cores count: {cpu_count}")
    sess_options.intra_op_num_threads = cpu_count
    session = onnxruntime.InferenceSession(
        model_path, providers=providers, sess_options=sess_options
    )
    return session


class KokoroTTS:
    """Kokoro ONNX text-to-speech with a per-language grapheme-to-phoneme front end."""

    # Upper-case language code -> Kokoro voice id shipped in voices-v1.0.bin.
    language_voice_mapping = {
        "JP": "jf_alpha",
        "JA": "jf_alpha",
        "ZH": "zf_xiaoyi",
        "EN": "af_heart",
        "FR": "ff_siwis",
        "IT": "im_nicola",
        "HI": "hf_alpha",
        "PT": "im_nicola",
        "ES": "im_nicola"
    }

    # Languages served by a plain espeak G2P:
    # language code -> (espeak language tag, warm-up text).
    _espeak_languages = {
        "HI": ("hi", "हेलो"),
        "IT": ("it", "Ciao"),
        "PT": ("pt-br", "Olá"),
        "ES": ("es", "Hola"),
        "FR": ("fr-fr", "Bonjour"),
    }

    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, gcp=None, voice=None):
        """Load the ONNX model and wire up the G2P front end.

        Args:
            model_path: Path to the Kokoro ``.onnx`` model.
            voice_model_path: Path to the voices binary (``voices-v1.0.bin``).
            vocab_config: Optional path to a language-specific vocab config.
            gcp: Grapheme-to-phoneme callable returning ``(phonemes, tokens)``.
                (NOTE(review): name looks like a typo for ``g2p``, but it is
                part of the public keyword interface, so it is kept.)
            voice: Kokoro voice identifier (see ``language_voice_mapping``).
        """
        self._session = create_session(model_path)
        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
        self.g2p = gcp
        self.voice = voice

    @classmethod
    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
        """Build a TTS instance with the right voice and G2P for *language*.

        Most languages get a short warm-up synthesis so the first real
        request is fast (JP/JA intentionally do not, matching prior behavior).

        Raises:
            ValueError: If *language* has no entry in ``language_voice_mapping``.
        """
        model_path: str = str(model_dir / "kokoro-quant.onnx")
        voice_model_path: str = str(model_dir / "voices-v1.0.bin")
        lang = language.upper()  # hoisted: previously re-computed in every branch
        voice = cls.language_voice_mapping.get(lang)
        logger.info(f"[TTS] language: {language}")
        if not voice:
            raise ValueError(f"Unsupported language: {language}, voice: {voice}")
        if lang == "ZH":
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "zh_config.json", gcp=zh.ZHG2P(),
                      voice=voice)
            tts.generate("你好")  # warm-up
        elif lang in ("JP", "JA"):
            tts = cls(model_path, voice_model_path, vocab_config=model_dir / "ja_config.json", gcp=ja.JAG2P(),
                      voice=voice)
        elif lang == "EN":
            fallback = espeak.EspeakFallback(british=False)
            tts = cls(model_path, voice_model_path, gcp=en.G2P(trf=False, british=False, fallback=fallback),
                      voice=voice)
            tts.generate("hello")  # warm-up
        elif lang in cls._espeak_languages:
            # HI / IT / PT / ES / FR all follow the same shape; table-driven
            # instead of five copy-pasted elif branches.
            espeak_lang, warmup_text = cls._espeak_languages[lang]
            tts = cls(model_path, voice_model_path, gcp=EspeakG2P(language=espeak_lang), voice=voice)
            tts.generate(warmup_text)  # warm-up
        else:
            # Defensive fallback; currently unreachable because every key of
            # language_voice_mapping is handled by a branch above.
            tts = cls(model_path, voice_model_path, gcp=EspeakG2P(language.lower()), voice=voice)
        return tts

    def generate(self, text, speed=1.2):
        """Synthesize *text* and return ``(samples, sample_rate, seconds_taken)``."""
        with Timer("tts inference") as t:
            phonemes, _ = self.g2p(text)
            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)

        return samples, sample_rate, t.duration

    async def stream(self, text, speed=1.2):
        """Asynchronously yield ``(samples, sample_rate)`` chunks for *text*."""
        phonemes, _ = self.g2p(text)
        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
        async for samples, sample_rate in stream:
            yield samples, sample_rate


@lru_cache
def get_model(language):
    """Return a cached KokoroTTS instance for *language*.

    ``lru_cache`` ensures the expensive model load happens at most once per
    language code for the process lifetime.
    """
    # BUG FIX: KokoroTTS.from_language has no `model_dir_path` parameter —
    # the keyword is `model_dir` — so the original call raised TypeError.
    # Path() makes the `/` joins inside from_language work even if
    # resource_path returns a str (idempotent if it already returns a Path).
    # NOTE(review): `resource_path` is not defined or imported in this file's
    # visible scope — confirm it is provided elsewhere.
    return KokoroTTS.from_language(language=language, model_dir=Path(resource_path('models/kokoro')))