File size: 2,287 Bytes
855c74b
ca14807
 
855c74b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca14807
855c74b
 
 
 
ca14807
 
 
 
 
 
 
 
 
855c74b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import atexit
import os
import tempfile
from functools import lru_cache
from pathlib import Path

import sherpa_onnx
from huggingface_hub import hf_hub_download, snapshot_download

# The only Hugging Face repo id accepted by get_pretrained_model(); any other
# value (after URL normalization) is rejected with ValueError.
ENGLISH_REPO_ID = "vidhi0405/TextToSpeech"


def _normalize_repo_id(repo_id: str) -> str:
    v = repo_id.strip()
    if v.startswith("https://huggingface.co/"):
        v = v.removeprefix("https://huggingface.co/").strip("/")
    return v


def _get_file(repo_id: str, filename: str, subfolder: str) -> str:
    """Download a single file from a Hugging Face Hub repo.

    Args:
        repo_id: Repo to download from, in ``owner/name`` form.
        filename: Name of the file inside ``subfolder``.
        subfolder: Directory within the repo that contains the file.

    Returns:
        Whatever ``hf_hub_download`` returns for the request; callers in this
        module use it as a local filesystem path.
    """
    download_args = {
        "repo_id": repo_id,
        "filename": filename,
        "subfolder": subfolder,
    }
    return hf_hub_download(**download_args)


def _remove_file_quietly(path: str) -> None:
    """Delete ``path`` as best-effort cleanup, ignoring OS-level failures."""
    try:
        os.remove(path)
    except OSError:
        # Runs at interpreter exit: a file that is already gone (or cannot be
        # removed) must not turn shutdown into a traceback.
        pass


@lru_cache(maxsize=2)
def get_pretrained_model(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build a Kokoro English ``sherpa_onnx.OfflineTts`` engine.

    The model, tokens, voices and espeak-ng data are fetched from the
    ``kokoro-en-v0_19`` folder of the Hugging Face repo. Results are memoized
    for the two most recently used ``(repo_id, speed)`` argument pairs.

    Args:
        repo_id: Repo id or ``https://huggingface.co/...`` URL. After
            normalization it must equal ``ENGLISH_REPO_ID``.
        speed: Speaking-rate factor; the model's ``length_scale`` is set to
            ``1.0 / speed``, so it must be strictly positive.

    Returns:
        A configured ``sherpa_onnx.OfflineTts`` instance running on CPU.

    Raises:
        ValueError: If ``repo_id`` is not the supported repo, or ``speed`` is
            not a positive number.
    """
    source_repo = _normalize_repo_id(repo_id)
    if source_repo != ENGLISH_REPO_ID:
        raise ValueError(f"Unsupported repo_id: {repo_id}. Use {ENGLISH_REPO_ID}")

    # Reject bad speeds before any download happens. ``not speed > 0`` also
    # catches NaN, which would otherwise slip through a ``speed <= 0`` check.
    if not speed > 0:
        raise ValueError(f"speed must be a positive number, got {speed!r}")

    subfolder = "kokoro-en-v0_19"

    model = _get_file(
        repo_id=source_repo,
        filename="model.onnx",
        subfolder=subfolder,
    )
    tokens_raw = _get_file(
        repo_id=source_repo,
        filename="tokens.txt",
        subfolder=subfolder,
    )

    # Sanitize tokens file to prevent parsing errors (e.g. empty lines)
    with open(tokens_raw, "r", encoding="utf-8") as src:
        lines = [line for line in src if line.strip()]

    fd, tokens = tempfile.mkstemp(suffix=".txt", text=True)
    # mkstemp never deletes its file. Register cleanup before writing so the
    # sanitized copy is removed at interpreter exit even if a later step
    # fails, while staying on disk for as long as the engine may need it.
    atexit.register(_remove_file_quietly, tokens)
    with os.fdopen(fd, "w", encoding="utf-8") as dst:
        dst.writelines(lines)

    voices = _get_file(
        repo_id=source_repo,
        filename="voices.bin",
        subfolder=subfolder,
    )

    # The espeak-ng data is a directory tree, so pull it via a filtered
    # snapshot instead of file-by-file downloads.
    root_dir = snapshot_download(
        repo_id=source_repo,
        allow_patterns=[f"{subfolder}/espeak-ng-data/*"],
    )
    data_dir = str(Path(root_dir) / subfolder / "espeak-ng-data")

    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            kokoro=sherpa_onnx.OfflineTtsKokoroModelConfig(
                model=model,
                voices=voices,
                tokens=tokens,
                data_dir=data_dir,
                # Higher speed => shorter durations, hence the reciprocal.
                length_scale=1.0 / speed,
            ),
            provider="cpu",
            debug=True,
            num_threads=2,
        ),
        max_num_sentences=1,
    )
    return sherpa_onnx.OfflineTts(tts_config)