import os
from pathlib import Path
from kokoro_onnx import Kokoro
from misaki import espeak, en, zh
from misaki.espeak import EspeakG2P
from logging import getLogger
import onnxruntime

from lib.utils import Timer, write_audio


logger = getLogger(__name__)
# Kept for backward compatibility (may be imported elsewhere); create_session
# queries the available providers itself.
providers = onnxruntime.get_available_providers()
# Default location of the Kokoro model files. Overridable via the
# KOKORO_MODEL_DIR environment variable; the fallback preserves the original
# hard-coded path (with the accidental doubled leading slash removed).
MODEL_DIR = Path(
    os.environ.get(
        "KOKORO_MODEL_DIR",
        "/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models/kokoro",
    )
)

def create_session(model_path):
    """Create a CPU-only ONNX Runtime inference session for *model_path*.

    Intra-op threading is set to half the machine's CPU cores so inference
    is fast without saturating the host.

    Args:
        model_path: Path to the ``.onnx`` model file.

    Returns:
        A ready-to-use ``onnxruntime.InferenceSession``.
    """
    # See list of providers https://github.com/microsoft/onnxruntime/issues/22101#issuecomment-2357667377
    available_providers = onnxruntime.get_available_providers()
    logger.info("Available onnx runtime providers: %s", available_providers)

    # See session options https://onnxruntime.ai/docs/performance/tune-performance/threading.html#thread-management
    sess_options = onnxruntime.SessionOptions()
    # os.cpu_count() may return None, and on a single-core machine the
    # original `// 2` yielded 0; clamp to at least one thread.
    cpu_count = max(1, (os.cpu_count() or 2) // 2)
    logger.info("Setting threads to CPU cores count: %s", cpu_count)
    sess_options.intra_op_num_threads = cpu_count
    session = onnxruntime.InferenceSession(
        model_path, providers=["CPUExecutionProvider"], sess_options=sess_options
    )
    return session


class KokoroTTS:
    """Text-to-speech engine wrapping the Kokoro ONNX model.

    Maps a language code to a Kokoro voice plus a misaki G2P frontend and
    exposes blocking (:meth:`generate`) and streaming (:meth:`stream`)
    synthesis.
    """

    # Kokoro voice id used for each supported language code.
    language_voice_mapping = {
        "JP": "jf_alpha",
        "JA": "jf_alpha",
        "ZH": "zf_xiaoyi",
        "EN": "af_heart",
        "FR": "ff_siwis",
        "IT": "im_nicola",
        "HI": "hf_alpha",
        "PT": "im_nicola",
        "ES": "im_nicola"
    }
    # Short warm-up phrase per language, synthesized once at load time so the
    # first real request does not pay the model warm-up cost. Previously
    # JP/JA had a voice but no phrase, which crashed from_language() with a
    # None warm-up text; they are included now.
    language_word_mapping = {
        "JP": "こんにちは",
        "JA": "こんにちは",
        "ZH": "你好",
        "EN": "hello",
        "FR": "Bonjour",
        "IT": "Ciao",
        "HI": "हेलो",
        "PT": "Olá",
        "ES": "Hola"
    }
    # espeak language codes for languages served by the generic EspeakG2P.
    _espeak_language_codes = {
        "HI": "hi",
        "IT": "it",
        "PT": "pt-br",
        "ES": "es",
        "FR": "fr-fr",
    }

    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, gcp=None, voice=None):
        """Load the ONNX model and voice bank.

        Args:
            model_path: Path to the Kokoro ``.onnx`` model file.
            voice_model_path: Path to the ``voices-*.bin`` voice bank.
            vocab_config: Optional vocab config path (used for Chinese).
            gcp: G2P callable mapping text -> (phonemes, extra). (Parameter
                name kept as-is — historical typo for "g2p" — so existing
                keyword callers keep working.)
            voice: Kokoro voice id to synthesize with.
        """
        self._session = create_session(model_path)
        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
        self.g2p = gcp
        self.voice = voice

    @classmethod
    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
        """Build a :class:`KokoroTTS` for *language* from models in *model_dir*.

        Raises:
            ValueError: If no voice is configured for *language*.
        """
        lang = language.upper()
        model_path: str = str(model_dir / "kokoro-quant.onnx")
        voice_model_path: str = str(model_dir / "voices-v1.0.bin")
        voice = cls.language_voice_mapping.get(lang)
        logger.info(f"[TTS] language: {language}")
        if not voice:
            raise ValueError(f"Unsupported language: {language}, voice: {voice}")
        vocab_config = None
        if lang == "ZH":
            # Chinese uses a dedicated G2P plus a custom vocab config.
            g2p = zh.ZHG2P()
            vocab_config = model_dir / "zh_config.json"
        elif lang == "EN":
            # English: misaki G2P with espeak as out-of-vocabulary fallback.
            fallback = espeak.EspeakFallback(british=False)
            g2p = en.G2P(trf=False, british=False, fallback=fallback)
        else:
            # All other languages go through espeak; unmapped codes fall back
            # to the lowercased language code, matching the old behavior.
            espeak_code = cls._espeak_language_codes.get(lang, language.lower())
            g2p = EspeakG2P(language=espeak_code)
        with Timer("load tts"):
            tts = cls(model_path, voice_model_path, vocab_config=vocab_config, gcp=g2p, voice=voice)
        # Warm the model once so the first user-facing request is fast; skip
        # gracefully when no warm-up phrase is configured.
        warm_up_text = cls.language_word_mapping.get(lang)
        if warm_up_text:
            tts.generate(warm_up_text)
        return tts

    def generate(self, text, speed=1.2):
        """Synthesize *text* and return (samples, sample_rate, seconds_taken)."""
        with Timer("tts inference") as t:
            phonemes, _ = self.g2p(text)
            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)
        return samples, sample_rate, t.duration

    async def stream(self, text, speed=1.2):
        """Asynchronously yield (samples, sample_rate) chunks for *text*."""
        phonemes, _ = self.g2p(text)
        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
        async for samples, sample_rate in stream:
            yield samples, sample_rate

if __name__ == '__main__':
    # Manual smoke test: synthesize a Chinese sentence to a wav file and
    # report how long inference took.
    engine = KokoroTTS.from_language(language="ZH")
    audio, rate, elapsed = engine.generate("今天天气怎么样?")
    write_audio("tts_out.wav", audio, rate)
    print(elapsed)