# so / model.py
# Last commit 707dad3 (verified) by leesenx: "Update model.py"
import os
import json
from functools import lru_cache
from pathlib import Path
# Quiet the Hugging Face Hub client (no progress bars, no telemetry) unless the
# environment already sets these flags. This sits above the `huggingface_hub`
# import — presumably so the library sees the flags when it is first imported.
for _hf_env_flag in ("HF_HUB_DISABLE_PROGRESS_BARS", "HF_HUB_DISABLE_TELEMETRY"):
    os.environ.setdefault(_hf_env_flag, "1")
del _hf_env_flag
import sherpa_onnx
from huggingface_hub import hf_hub_download, snapshot_download
def get_file(repo_id: str, filename: str, subfolder: str = ".") -> str:
    """Fetch a single file from a Hugging Face Hub repo.

    Args:
        repo_id: Hub repository id, e.g. ``"csukuangfj/vits-zh-aishell3"``.
        filename: File name inside ``subfolder``.
        subfolder: Sub-directory within the repo (``"."`` for the repo root).

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    local_path = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename=filename,
    )
    return local_path
@lru_cache(maxsize=10)
def _get_kokoro(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build a Kokoro ``sherpa_onnx.OfflineTts`` engine.

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; only the part before
            ``|`` is used as the Hugging Face repo id.
        speed: Speaking-rate factor, applied as ``length_scale = 1.0 / speed``.

    Returns:
        An ``OfflineTts`` on the CPU provider with 2 threads. Results are
        cached per ``(repo_id, speed)`` by ``lru_cache``.

    Repo ids containing ``"int8"`` load ``model.int8.onnx`` instead of
    ``model.onnx``. Repo ids containing ``"multi-lang"`` also load the
    US-English and Chinese lexicons, the Chinese date/phone/number rule FSTs,
    and the jieba dictionary from ``/tmp/dict`` (downloaded if missing).
    """
    # NOTE(review): /tmp/espeak-ng-data is not created here; presumably it is
    # provisioned elsewhere before models are loaded — confirm.
    data_dir = "/tmp/espeak-ng-data"
    clean_id = repo_id.split("|")[0]
    model_name = "model.int8.onnx" if "int8" in clean_id else "model.onnx"
    model = get_file(repo_id=clean_id, filename=model_name, subfolder=".")
    tokens = get_file(repo_id=clean_id, filename="tokens.txt", subfolder=".")
    voices = get_file(repo_id=clean_id, filename="voices.bin", subfolder=".")

    lexicon = ""
    rule_fsts = ""
    dict_dir = ""
    if "multi-lang" in clean_id:
        lexicon_en = get_file(repo_id=clean_id, filename="lexicon-us-en.txt", subfolder=".")
        lexicon_zh = get_file(repo_id=clean_id, filename="lexicon-zh.txt", subfolder=".")
        lexicon = f"{lexicon_en},{lexicon_zh}"
        date_zh = get_file(repo_id=clean_id, filename="date-zh.fst", subfolder=".")
        number_zh = get_file(repo_id=clean_id, filename="number-zh.fst", subfolder=".")
        phone_zh = get_file(repo_id=clean_id, filename="phone-zh.fst", subfolder=".")
        # Joined as date, phone, number (this differs from the download order).
        rule_fsts = f"{date_zh},{phone_zh},{number_zh}"
        dict_dir = "/tmp/dict"
        # Make sure the jieba dictionary really exists at dict_dir. Without
        # this, the model only works after some other loader has already
        # fetched it. Same archive and location as the other Chinese loaders.
        if not Path(dict_dir).is_dir():
            os.system("cd /tmp; curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2; tar xf dict.tar.bz2")

    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            kokoro=sherpa_onnx.OfflineTtsKokoroModelConfig(
                model=model,
                voices=voices,
                tokens=tokens,
                data_dir=data_dir,
                length_scale=1.0 / speed,
                lexicon=lexicon,
                dict_dir=dict_dir,
            ),
            provider="cpu",
            debug=False,
            num_threads=2,
        ),
        max_num_sentences=1,
        rule_fsts=rule_fsts,
    )
    return sherpa_onnx.OfflineTts(tts_config)
@lru_cache(maxsize=10)
def _get_supertonic(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build a Supertonic ``sherpa_onnx.OfflineTts`` engine from its int8 parts.

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; the part before ``|`` is
            the Hugging Face repo id.
        speed: Not referenced when building this config; it only takes part
            in the ``lru_cache`` key. NOTE(review): if speed should affect
            Supertonic output, it presumably has to be applied elsewhere
            (e.g. at generation time) — confirm with the caller.

    Returns:
        An ``OfflineTts`` on the CPU provider with 2 threads.
    """
    repo = repo_id.split("|")[0]
    # Supertonic config field -> file name in the repo (downloaded in this order).
    component_files = {
        "duration_predictor": "duration_predictor.int8.onnx",
        "text_encoder": "text_encoder.int8.onnx",
        "vector_estimator": "vector_estimator.int8.onnx",
        "vocoder": "vocoder.int8.onnx",
        "tts_json": "tts.json",
        "unicode_indexer": "unicode_indexer.bin",
        "voice_style": "voice.bin",
    }
    local_paths = {
        field: get_file(repo_id=repo, filename=fname, subfolder=".")
        for field, fname in component_files.items()
    }
    model_cfg = sherpa_onnx.OfflineTtsModelConfig(
        supertonic=sherpa_onnx.OfflineTtsSupertonicModelConfig(**local_paths),
        provider="cpu",
        debug=False,
        num_threads=2,
    )
    tts_cfg = sherpa_onnx.OfflineTtsConfig(model=model_cfg, max_num_sentences=1)
    return sherpa_onnx.OfflineTts(tts_cfg)
@lru_cache(maxsize=10)
def _get_vits_piper(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build a VITS ``sherpa_onnx.OfflineTts`` engine for a Piper/Mimic3 export.

    The whole repo is fetched with ``snapshot_download``. The ONNX file name is
    the repo name with its ``vits-piper-`` / ``vits-mimic3-`` prefix removed
    (``model`` for any other repo).

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; only the part before
            ``|`` is used.
        speed: Speaking-rate factor, applied as ``length_scale = 1.0 / speed``.

    Returns:
        An ``OfflineTts`` on the CPU provider with 2 threads.
    """
    repo = repo_id.split("|")[0]

    # e.g. "csukuangfj/vits-piper-en_US-ryan-high" -> "en_US-ryan-high".
    stem = "model"
    for marker, prefix in (("piper", "vits-piper-"), ("mimic3", "vits-mimic3-")):
        if marker in repo:
            stem = repo.split("/")[1][len(prefix):]
            break

    snapshot_dir = snapshot_download(repo)

    lexicon_voices = (
        "vits-piper-zh_CN-chaowen-medium",
        "vits-piper-zh_CN-xiao_ya-medium",
    )
    if any(voice in repo for voice in lexicon_voices):
        # These voices clear the espeak-ng data_dir and use the repo's own
        # lexicon plus phone/date/number rule FSTs instead.
        espeak_dir = ""
        lexicon = f"{snapshot_dir}/lexicon.txt"
        rule_fsts = ",".join(
            f"{snapshot_dir}/{fst}" for fst in ("phone.fst", "date.fst", "number.fst")
        )
    else:
        espeak_dir = "/tmp/espeak-ng-data"
        lexicon = ""
        rule_fsts = ""

    vits_cfg = sherpa_onnx.OfflineTtsVitsModelConfig(
        model=f"{snapshot_dir}/{stem}.onnx",
        lexicon=lexicon,
        data_dir=espeak_dir,
        tokens=f"{snapshot_dir}/tokens.txt",
        length_scale=1.0 / speed,
    )
    model_cfg = sherpa_onnx.OfflineTtsModelConfig(
        vits=vits_cfg,
        matcha=sherpa_onnx.OfflineTtsMatchaModelConfig(),
        provider="cpu",
        debug=False,
        num_threads=2,
    )
    tts_cfg = sherpa_onnx.OfflineTtsConfig(
        model=model_cfg,
        max_num_sentences=1,
        rule_fsts=rule_fsts,
    )
    return sherpa_onnx.OfflineTts(tts_cfg)
@lru_cache(maxsize=10)
def _get_vits_zh_aishell3(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build the multi-speaker VITS aishell3 (Mandarin) engine.

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; only the part before
            ``|`` is used.
        speed: Speaking-rate factor, applied as ``length_scale = 1.0 / speed``.

    Returns:
        An ``OfflineTts`` with phone/date/number/heteronym rule FSTs and the
        repo's ``rule.far``, on the CPU provider with 2 threads.
    """
    repo = repo_id.split("|")[0]

    def fetch(name: str) -> str:
        return get_file(repo_id=repo, filename=name, subfolder=".")

    model, lexicon, tokens = map(fetch, ("vits-aishell3.onnx", "lexicon.txt", "tokens.txt"))
    rule_fsts = ",".join(
        map(fetch, ("phone.fst", "date.fst", "number.fst", "new_heteronym.fst"))
    )
    rule_fars = fetch("rule.far")

    vits_cfg = sherpa_onnx.OfflineTtsVitsModelConfig(
        model=model,
        lexicon=lexicon,
        tokens=tokens,
        length_scale=1.0 / speed,
    )
    model_cfg = sherpa_onnx.OfflineTtsModelConfig(
        vits=vits_cfg,
        matcha=sherpa_onnx.OfflineTtsMatchaModelConfig(),
        provider="cpu",
        debug=False,
        num_threads=2,
    )
    tts_cfg = sherpa_onnx.OfflineTtsConfig(
        model=model_cfg,
        rule_fsts=rule_fsts,
        rule_fars=rule_fars,
        max_num_sentences=1,
    )
    return sherpa_onnx.OfflineTts(tts_cfg)
@lru_cache(maxsize=10)
def _get_matcha_zh_en(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build the Matcha Chinese/English engine with the 16 kHz Vocos vocoder.

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; only the part before
            ``|`` is used.
        speed: Speaking-rate factor, applied as ``length_scale = 1.0 / speed``.

    Returns:
        An ``OfflineTts`` on the CPU provider with 2 threads, using espeak-ng
        data from ``/tmp/espeak-ng-data`` and Chinese rule FSTs.
    """
    repo = repo_id.split("|")[0]

    def fetch(name: str) -> str:
        return get_file(repo_id=repo, filename=name, subfolder=".")

    acoustic_model = fetch("model-steps-3.onnx")
    # The vocoder lives in a shared vocoder repo, not in the model repo.
    vocoder = get_file(
        repo_id="csukuangfj/sherpa-onnx-vocoders",
        filename="vocos-16khz-univ.onnx",
        subfolder=".",
    )
    lexicon = fetch("lexicon.txt")
    tokens = fetch("tokens.txt")
    rule_fsts = ",".join(map(fetch, ("phone-zh.fst", "date-zh.fst", "number-zh.fst")))

    matcha_cfg = sherpa_onnx.OfflineTtsMatchaModelConfig(
        acoustic_model=acoustic_model,
        vocoder=vocoder,
        lexicon=lexicon,
        tokens=tokens,
        data_dir="/tmp/espeak-ng-data",
        length_scale=1.0 / speed,
    )
    model_cfg = sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(),
        matcha=matcha_cfg,
        provider="cpu",
        debug=False,
        num_threads=2,
    )
    tts_cfg = sherpa_onnx.OfflineTtsConfig(
        model=model_cfg,
        rule_fsts=rule_fsts,
        rule_fars="",
        max_num_sentences=1,
    )
    return sherpa_onnx.OfflineTts(tts_cfg)
@lru_cache(maxsize=10)
def _get_matcha_hf(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build a Matcha (Mandarin) engine with the HiFiGAN v2 vocoder.

    Fetches the jieba dictionary into ``/tmp/dict`` first if it is missing.

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; only the part before
            ``|`` is used.
        speed: Speaking-rate factor, applied as ``length_scale = 1.0 / speed``.

    Returns:
        An ``OfflineTts`` on the CPU provider with 2 threads.
    """
    repo = repo_id.split("|")[0]
    jieba_dir = "/tmp/dict"
    if not Path(jieba_dir).is_dir():
        os.system("cd /tmp; curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2; tar xf dict.tar.bz2")

    def fetch(name: str) -> str:
        return get_file(repo_id=repo, filename=name, subfolder=".")

    acoustic_model = fetch("model-steps-3.onnx")
    # The vocoder comes from a shared HiFiGAN repo, not from the model repo.
    vocoder = get_file(
        repo_id="csukuangfj/sherpa-onnx-hifigan",
        filename="hifigan_v2.onnx",
        subfolder=".",
    )
    lexicon = fetch("lexicon.txt")
    tokens = fetch("tokens.txt")
    rule_fsts = ",".join(map(fetch, ("phone.fst", "date.fst", "number.fst")))

    matcha_cfg = sherpa_onnx.OfflineTtsMatchaModelConfig(
        acoustic_model=acoustic_model,
        vocoder=vocoder,
        lexicon=lexicon,
        tokens=tokens,
        dict_dir=jieba_dir,
        length_scale=1.0 / speed,
    )
    model_cfg = sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(),
        matcha=matcha_cfg,
        provider="cpu",
        debug=False,
        num_threads=2,
    )
    tts_cfg = sherpa_onnx.OfflineTtsConfig(
        model=model_cfg,
        rule_fsts=rule_fsts,
        rule_fars="",
        max_num_sentences=1,
    )
    return sherpa_onnx.OfflineTts(tts_cfg)
@lru_cache(maxsize=10)
def _get_vits_hf(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build a lexicon-based VITS engine for the Chinese/Cantonese HF exports.

    Used for repos such as the fanchen models, sherpa-onnx-vits-zh-ll,
    vits-melo-tts-zh_en and vits-cantonese-hf-xiaomaiiwn.

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; only the part before
            ``|`` is used.
        speed: Speaking-rate factor, applied as ``length_scale = 1.0 / speed``.

    Returns:
        An ``OfflineTts`` on the CPU provider with 2 threads.
    """
    clean_id = repo_id.split("|")[0]
    is_cantonese = "vits-cantonese-hf-xiaomaiiwn" in clean_id

    # fanchen and the Cantonese export name their ONNX file after the repo;
    # every other repo handled here uses "model.onnx".
    if "fanchen" in clean_id or is_cantonese:
        model = clean_id.split("/")[-1]
    else:
        model = "model"

    model_file = get_file(repo_id=clean_id, filename=f"{model}.onnx", subfolder=".")
    lexicon = get_file(repo_id=clean_id, filename="lexicon.txt", subfolder=".")
    tokens = get_file(repo_id=clean_id, filename="tokens.txt", subfolder=".")

    if is_cantonese:
        # The Cantonese export ships a single rule.fst and uses no jieba dict.
        rule_fsts = get_file(repo_id=clean_id, filename="rule.fst", subfolder=".")
        vits_dict_dir = ""
    else:
        rule_fsts = ",".join(
            get_file(repo_id=clean_id, filename=f, subfolder=".")
            for f in ("phone.fst", "date.fst", "number.fst")
        )
        vits_dict_dir = "/tmp/dict"

    # Fetch the jieba dictionary whenever dict_dir is set. This includes the
    # fanchen models, which also receive /tmp/dict. Otherwise, loading one of
    # them first on a fresh machine would point sherpa-onnx at a missing
    # directory.
    if vits_dict_dir and not Path(vits_dict_dir).is_dir():
        os.system("cd /tmp; curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2; tar xf dict.tar.bz2")

    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                model=model_file,
                lexicon=lexicon,
                tokens=tokens,
                dict_dir=vits_dict_dir,
                length_scale=1.0 / speed,
            ),
            matcha=sherpa_onnx.OfflineTtsMatchaModelConfig(),
            provider="cpu",
            debug=False,
            num_threads=2,
        ),
        rule_fsts=rule_fsts,
        rule_fars="",
        max_num_sentences=1,
    )
    return sherpa_onnx.OfflineTts(tts_config)
@lru_cache(maxsize=10)
def _get_melotts_onnx(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Build a MeloTTS VITS engine from the MiaoMint/MeloTTS-ONNX exports.

    The language sub-directory comes from the model key, e.g.
    ``"MiaoMint/MeloTTS-ONNX/zh|1 speaker"`` -> ``onnx_exports/zh/``.

    Args:
        repo_id: Model key ``"<org>/<repo>/<lang>|<label>"``. If there is no
            third path segment, the second one is used as the language.
        speed: Speaking-rate factor, applied as ``length_scale = 1.0 / speed``.

    Returns:
        An ``OfflineTts`` on the CPU provider with 2 threads.
    """
    repo_path = repo_id.split("|")[0]
    segments = repo_path.split("/")
    lang = segments[2] if len(segments) > 2 else segments[1]

    melo_repo = "MiaoMint/MeloTTS-ONNX"
    export_dir = f"onnx_exports/{lang}"
    model = hf_hub_download(repo_id=melo_repo, filename=f"{export_dir}/model.onnx")
    tokens = hf_hub_download(repo_id=melo_repo, filename=f"{export_dir}/tokens.txt")
    # The lexicon is optional: any failure to fetch it leaves it empty.
    try:
        lexicon_path = hf_hub_download(repo_id=melo_repo, filename=f"{export_dir}/lexicon.txt")
    except Exception:
        lexicon_path = ""

    espeak_dir = "/tmp/espeak-ng-data"
    if lang == "zh":
        # Chinese needs the jieba dictionary; fetch it on first use.
        dict_dir = "/tmp/dict"
        if not Path(dict_dir).is_dir():
            os.system("cd /tmp; curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2; tar xf dict.tar.bz2")
    else:
        dict_dir = ""

    vits_cfg = sherpa_onnx.OfflineTtsVitsModelConfig(
        model=model,
        lexicon=lexicon_path,
        data_dir=espeak_dir,
        tokens=tokens,
        dict_dir=dict_dir,
        length_scale=1.0 / speed,
    )
    model_cfg = sherpa_onnx.OfflineTtsModelConfig(
        vits=vits_cfg,
        matcha=sherpa_onnx.OfflineTtsMatchaModelConfig(),
        provider="cpu",
        debug=False,
        num_threads=2,
    )
    tts_cfg = sherpa_onnx.OfflineTtsConfig(model=model_cfg, max_num_sentences=1)
    return sherpa_onnx.OfflineTts(tts_cfg)
# Model registries: model key "<hf-repo>|<label>" -> loader function.
# Loaders only use the part before "|"; the label is for display.
# Insertion order matters: it becomes the order in language_to_models.
chinese_models = {
    "csukuangfj2/vits-piper-zh_CN-chaowen-medium|1 speaker": _get_vits_piper,
    "csukuangfj2/vits-piper-zh_CN-xiao_ya-medium|1 speaker": _get_vits_piper,
    "csukuangfj/matcha-icefall-zh-baker|1 speaker": _get_matcha_hf,
    "csukuangfj/vits-zh-hf-fanchen-wnj|1 speaker": _get_vits_hf,
    "csukuangfj/vits-zh-hf-fanchen-C|187 speakers": _get_vits_hf,
    "csukuangfj/sherpa-onnx-vits-zh-ll|5 speakers": _get_vits_hf,
    "csukuangfj/vits-zh-aishell3|174 speakers": _get_vits_zh_aishell3,
    "csukuangfj/vits-piper-zh_CN-huayan-medium|1 speaker": _get_vits_piper,
    "MiaoMint/MeloTTS-ONNX/zh|1 speaker": _get_melotts_onnx,
}
chinese_english_models = {
    "csukuangfj/matcha-icefall-zh-en|1": _get_matcha_zh_en,
    "csukuangfj/kokoro-multi-lang-v1_1|103 speakers": _get_kokoro,
    "csukuangfj/kokoro-int8-multi-lang-v1_1|103 speakers": _get_kokoro,
    "csukuangfj/kokoro-multi-lang-v1_0|53 speakers": _get_kokoro,
    "csukuangfj/kokoro-int8-multi-lang-v1_0|53 speakers": _get_kokoro,
    "csukuangfj2/sherpa-onnx-supertonic-3-tts-int8-2026-05-11|10 speakers": _get_supertonic,
    "csukuangfj/vits-melo-tts-zh_en|1": _get_vits_hf,
    "MiaoMint/MeloTTS-ONNX/zh|1 speaker": _get_melotts_onnx,
}
cantonese_models = {
    "csukuangfj/vits-cantonese-hf-xiaomaiiwn|1 speaker": _get_vits_hf,
}
english_models = {
    "csukuangfj/kokoro-en-v0_19|11 speakers": _get_kokoro,
    "csukuangfj2/sherpa-onnx-supertonic-3-tts-int8-2026-05-11|10 speakers": _get_supertonic,
    "MiaoMint/MeloTTS-ONNX/en_newest|1 speaker": _get_melotts_onnx,
    "csukuangfj/vits-piper-en_US-lessac-high|1 speaker": _get_vits_piper,
    "csukuangfj/vits-piper-en_US-ryan-high|1 speaker": _get_vits_piper,
    "csukuangfj/vits-piper-en_GB-alan-medium|1 speaker": _get_vits_piper,
    "csukuangfj/vits-piper-en_GB-jenny_dioco-medium|1 speaker": _get_vits_piper,
    "csukuangfj/vits-piper-en_GB-vctk-medium|109 speakers": _get_vits_piper,
}

# UI language label -> ordered list of model keys for that language.
language_to_models = {
    label: list(registry)
    for label, registry in (
        ("中文(普通话)", chinese_models),
        ("中英双语", chinese_english_models),
        ("粤语", cantonese_models),
        ("英语", english_models),
    )
}

# Flat key -> loader lookup across all registries. Keys shared by several
# registries (e.g. the Supertonic and MeloTTS zh entries) keep their first
# position and map to the same loader.
all_model_dicts = {}
for _registry in (chinese_models, chinese_english_models, cantonese_models, english_models):
    all_model_dicts.update(_registry)
del _registry
# Speaker tables: the position in each tuple is the speaker id (sid).
_KOKORO_EN_V019_SPEAKERS = (
    "美式女声-af", "美式女声-bella", "美式女声-nicole", "美式女声-sarah",
    "美式女声-sky", "美式男声-adam", "美式男声-michael", "英式女声-emma",
    "英式女声-isabella", "英式男声-george", "英式男声-lewis",
)

_KOKORO_MULTI_V10_SPEAKERS = (
    # sid 0-19: American English
    "美式女声-alloy", "美式女声-aoede", "美式女声-bella", "美式女声-heart",
    "美式女声-jessica", "美式女声-kore", "美式女声-nicole", "美式女声-nova",
    "美式女声-river", "美式女声-sarah", "美式女声-sky",
    "美式男声-adam", "美式男声-echo", "美式男声-eric", "美式男声-fenrir",
    "美式男声-liam", "美式男声-michael", "美式男声-onyx", "美式男声-puck",
    "美式男声-santa",
    # sid 20-27: British English
    "英式女声-alice", "英式女声-emma", "英式女声-isabella", "英式女声-lily",
    "英式男声-daniel", "英式男声-fable", "英式男声-george", "英式男声-lewis",
    # sid 28-29: Spanish (Kokoro voices ef_dora / em_alex), not English.
    "西语女声-dora", "西语男声-alex",
    # sid 30: French
    "法语女声-siwis",
    # sid 31-34: Hindi
    "印地语女声-alpha", "印地语女声-beta", "印地语男声-omega", "印地语男声-psi",
    # sid 35-36: Italian
    "意语女声-sara", "意语男声-nicola",
    # sid 37-41: Japanese
    "日语女声-alpha", "日语女声-gongitsune", "日语女声-nezumi", "日语女声-tebukuro",
    "日语男声-kumo",
    # sid 42-44: Portuguese
    "葡语女声-dora", "葡语男声-alex", "葡语男声-santa",
    # sid 45-52: Mandarin Chinese
    "中文女声-小北", "中文女声-小妮", "中文女声-小小", "中文女声-小艺",
    "中文男声-云剑", "中文男声-云希", "中文男声-云夏", "中文男声-云扬",
)

# Kokoro v1_1: three English voices (sid 0-2), then numbered Chinese female
# voices (sid 3-57), then numbered Chinese male voices (sid 58-102).
_KOKORO_V11_ZH_FEMALE_IDS = (
    1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 21, 22, 23, 24, 26, 27, 28, 32, 36,
    38, 39, 40, 42, 43, 44, 46, 47, 48, 49, 51, 59, 60, 67, 70, 71, 72, 73,
    74, 75, 76, 77, 78, 79, 83, 84, 85, 86, 87, 88, 90, 92, 93, 94, 99,
)
_KOKORO_V11_ZH_MALE_IDS = (
    9, 10, 11, 12, 13, 14, 15, 16, 20, 25, 29, 30, 31, 33, 34, 35, 37, 41,
    45, 50, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 68, 69, 80,
    81, 82, 89, 91, 95, 96, 97, 98, 100,
)
_KOKORO_MULTI_V11_SPEAKERS = (
    "美式女声-maple", "美式女声-sol", "英式女声-vale",
    *(f"中文女声-{n:03d}" for n in _KOKORO_V11_ZH_FEMALE_IDS),
    *(f"中文男声-{n:03d}" for n in _KOKORO_V11_ZH_MALE_IDS),
)

_KOKORO_SPEAKERS_BY_REPO = {
    "csukuangfj/kokoro-en-v0_19": _KOKORO_EN_V019_SPEAKERS,
    "csukuangfj/kokoro-multi-lang-v1_1": _KOKORO_MULTI_V11_SPEAKERS,
    "csukuangfj/kokoro-multi-lang-v1_0": _KOKORO_MULTI_V10_SPEAKERS,
    "csukuangfj/kokoro-int8-multi-lang-v1_1": _KOKORO_MULTI_V11_SPEAKERS,
    "csukuangfj/kokoro-int8-multi-lang-v1_0": _KOKORO_MULTI_V10_SPEAKERS,
}

_SUPERTONIC_SPEAKERS = (
    "男声-M1(活泼自信)",
    "男声-M2(深沉稳重)",
    "男声-M3(专业权威)",
    "男声-M4(柔和亲切)",
    "男声-M5(温暖舒缓)",
    "女声-F1(沉稳从容)",
    "女声-F2(明快活泼)",
    "女声-F3(清晰专业)",
    "女声-F4(干练自信)",
    "女声-F5(温柔平和)",
)

_VITS_ZH_LL_SPEAKERS = ("女声0", "女声1", "男声0", "男声1", "男声2")


@lru_cache(maxsize=32)
def get_speaker_map(repo_id: str) -> dict:
    """Return a ``{speaker_id: display_name}`` mapping for a model key.

    Args:
        repo_id: Model key ``"<hf-repo>|<label>"``; only the part before
            ``|`` is inspected.

    Returns:
        A dict mapping integer speaker ids to display names. Returns ``{}``
        when no names are known for the model, or when the best-effort lookup
        for Piper / fanchen-C metadata fails. The result is cached, so repeated
        calls return the same dict object; callers must not mutate it.
    """
    clean_id = repo_id.split("|")[0]

    if "piper" in clean_id:
        # Read the export's "<voice>.onnx.json" and invert its
        # "speaker_id_map" (presumably name -> sid). This is best effort: any
        # failure falls through to the checks below.
        try:
            local_dir = snapshot_download(clean_id)
            for fn in sorted(os.listdir(local_dir)):
                if fn.endswith(".onnx.json"):
                    # JSON is UTF-8; don't rely on the locale's default encoding,
                    # otherwise non-ASCII names can fail to decode.
                    with open(os.path.join(local_dir, fn), encoding="utf-8") as f:
                        data = json.load(f)
                    sid_map = data.get("speaker_id_map", {})
                    return {v: k for k, v in sid_map.items()}
        except Exception:
            pass

    if "vits-zh-hf-fanchen-C" in clean_id:
        # The "speakers" list in G_C.json is indexed by sid (best effort).
        try:
            config_path = hf_hub_download(repo_id=clean_id, filename="G_C.json")
            with open(config_path, encoding="utf-8") as fp:
                data = json.load(fp)
            return dict(enumerate(data.get("speakers", [])))
        except Exception:
            pass

    if "vits-zh-aishell3" in clean_id:
        return {}

    kokoro_names = _KOKORO_SPEAKERS_BY_REPO.get(clean_id)
    if kokoro_names is not None:
        return dict(enumerate(kokoro_names))

    if "supertonic" in clean_id:
        return dict(enumerate(_SUPERTONIC_SPEAKERS))

    if "vits-zh-ll" in clean_id:
        return dict(enumerate(_VITS_ZH_LL_SPEAKERS))

    return {}
@lru_cache(maxsize=10)
def get_pretrained_model(repo_id: str, speed: float) -> sherpa_onnx.OfflineTts:
    """Return the TTS engine for a registered model key.

    Args:
        repo_id: A key of ``all_model_dicts`` (``"<hf-repo>|<label>"``).
        speed: Speaking-rate factor forwarded to the model's loader.

    Returns:
        The ``sherpa_onnx.OfflineTts`` built by the registered loader.

    Raises:
        ValueError: If ``repo_id`` is not a registered model key.
    """
    loader = all_model_dicts.get(repo_id)
    if loader is None:
        raise ValueError(f"不支持的模型: {repo_id}")
    return loader(repo_id, speed)