ZipVoice.AXERA / scripts /local_tokenizer.py
HY-2012's picture
First commit
ea47387 verified
Raw
History Blame Contribute Delete
7.45 kB
from __future__ import annotations
import logging
import re
from functools import reduce
from pathlib import Path
from typing import Dict, List
import jieba
from pypinyin import Style, lazy_pinyin
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
try:
    from piper_phonemize import phonemize_espeak
except Exception as ex:  # pragma: no cover - board dependency check
    # Fail at import time with an actionable message; chain the original
    # error explicitly so it is reported as the direct cause.
    raise RuntimeError(
        f"{ex}\nPlease install piper_phonemize for English tokenization."
    ) from ex

# Raise jieba's logger threshold to INFO so its DEBUG-level messages are not
# emitted (presumably the dictionary-loading chatter on first cut — confirm).
jieba.default_logger.setLevel(logging.INFO)
class EnglishTextNormalizer:
    """Pass-through normalizer for English text.

    Exposes the same ``normalize(text)`` interface as the Chinese normalizer,
    but currently performs no transformation at all.
    """

    def normalize(self, text: str) -> str:
        """Return *text* exactly as given."""
        normalized = text
        return normalized
class ChineseTextNormalizer:
    """Best-effort Chinese text normalizer backed by the optional ``cn2an`` package.

    ``normalize`` runs ``cn2an.transform(text, "an2cn")``. If ``cn2an`` cannot
    be imported, or the transform raises, the text is returned unchanged.
    """

    # Sentinel meaning "the cn2an import has not been attempted yet".
    _UNSET = object()
    # Cached cn2an module, None when it could not be imported, or _UNSET.
    _cn2an = _UNSET

    def normalize(self, text: str) -> str:
        """Return *text* normalized by cn2an, or unchanged if that is not possible."""
        cn2an = self._load_cn2an()
        if cn2an is None:
            return text
        try:
            return cn2an.transform(text, "an2cn")
        except Exception as ex:
            # Deliberately best-effort: normalization must never break tokenization.
            logging.debug("cn2an normalization failed: %s", ex)
            return text

    @classmethod
    def _load_cn2an(cls):
        """Import cn2an once and cache the module (None when unavailable)."""
        if cls._cn2an is cls._UNSET:
            try:
                import cn2an
            except Exception as ex:  # optional dependency; stay best-effort
                logging.debug("cn2an unavailable, skipping normalization: %s", ex)
                cn2an = None
            cls._cn2an = cn2an
        return cls._cn2an
class LocalEmiliaTokenizer:
    """Small board-side Emilia tokenizer without lhotse/CutSet dependencies.

    Turns mixed Chinese/English text into phoneme tokens and then into integer
    ids. Besides plain text, the input may contain inline ``<pinyin>`` groups
    (e.g. ``<ni3>``), which are split into initial/final tokens, and ``[tag]``
    groups, which are passed through as single tokens.
    """

    # Splits text into bracketed groups ("<...>" or "[...]") or single characters.
    _PART_RE = re.compile(r"[<[].*?[>\]]|.")
    # Same bracket pattern with a capture group, so split() keeps the groups.
    _BRACKET_SPLIT_RE = re.compile(r"([<[].*?[>\]])")

    def __init__(self, token_file: str | Path):
        """Load the token vocabulary.

        Args:
            token_file: UTF-8 text file with one ``token<TAB>id`` pair per
                line. Blank lines are skipped.

        Raises:
            ValueError: A non-blank line does not split into exactly two
                tab-separated fields, or its id is not an integer.
            KeyError: The vocabulary has no ``"_"`` entry (used as ``pad_id``).
        """
        self.english_normalizer = EnglishTextNormalizer()
        self.chinese_normalizer = ChineseTextNormalizer()
        self.token2id: Dict[str, int] = {}
        with open(token_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # Tolerate blank / trailing empty lines instead of failing
                    # on the tuple unpacking below.
                    continue
                token, token_id = line.split("\t")
                self.token2id[token] = int(token_id)
        self.pad_id = self.token2id["_"]
        self.vocab_size = len(self.token2id)

    def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]:
        """Tokenize each text and map its tokens to ids (OOV tokens dropped)."""
        return self.tokens_to_token_ids(self.texts_to_tokens(texts))

    def texts_to_tokens(self, texts: List[str]) -> List[List[str]]:
        """Convert each text into a flat list of phoneme / tag tokens.

        Segments whose type is not zh/en/pinyin/tag (e.g. a text made only of
        punctuation, which is classified ``"other"``) are skipped.
        """
        phoneme_list = []
        for text in texts:
            text = self.map_punctuations(text)
            all_phoneme = []
            for seg_text, seg_type in self.get_segment(text):
                if seg_type == "zh":
                    all_phoneme += self.tokenize_zh(seg_text)
                elif seg_type == "en":
                    all_phoneme += self.tokenize_en(seg_text)
                elif seg_type == "pinyin":
                    all_phoneme += self.tokenize_pinyin(seg_text)
                elif seg_type == "tag":
                    all_phoneme.append(seg_text)
                else:
                    logging.debug(
                        "Skipping unknown language segment: %r", (seg_text, seg_type)
                    )
            phoneme_list.append(all_phoneme)
        return phoneme_list

    def tokens_to_token_ids(self, tokens_list: List[List[str]]) -> List[List[int]]:
        """Map token lists to id lists, silently dropping out-of-vocabulary tokens."""
        token_ids_list = []
        for tokens in tokens_list:
            token_ids = []
            for token in tokens:
                token_id = self.token2id.get(token)
                if token_id is None:
                    logging.debug("Skip OOV token %s", token)
                    continue
                token_ids.append(token_id)
            token_ids_list.append(token_ids)
        return token_ids_list

    def tokenize_zh(self, text: str) -> List[str]:
        """Convert Chinese text to pinyin initial/final tokens.

        The text is normalized, word-segmented with jieba and converted to
        TONE3 pinyin (tone sandhi applied, neutral tone written as ``5``).
        Items that are not TONE3 syllables (e.g. punctuation) are kept as-is.
        Returns an empty list if any step raises.
        """
        try:
            text = self.chinese_normalizer.normalize(text)
            segs = list(jieba.cut(text))
            full = lazy_pinyin(
                segs,
                style=Style.TONE3,
                tone_sandhi=True,
                neutral_tone_with_five=True,
            )
            phones = []
            for item in full:
                if self._is_tone3_pinyin(item):
                    phones.extend(self.separate_pinyin(item))
                else:
                    phones.append(item)
            return phones
        except Exception as ex:
            logging.debug("Tokenization of Chinese text failed: %s", ex)
            return []

    def tokenize_en(self, text: str) -> List[str]:
        """Convert English text to espeak ``en-us`` phonemes (flattened across sentences).

        Returns an empty list if phonemization raises.
        """
        try:
            text = self.english_normalizer.normalize(text)
            sentences = phonemize_espeak(text, "en-us")
            # Linear-time flatten; unlike reduce() it also handles an empty
            # result without raising.
            return [phone for sentence in sentences for phone in sentence]
        except Exception as ex:
            logging.debug("Tokenization of English text failed: %s", ex)
            return []

    def tokenize_pinyin(self, text: str) -> List[str]:
        """Convert an inline ``<pinyin>`` group (e.g. ``<ni3>``) to initial/final tokens.

        Returns an empty list if the content is not a TONE3 syllable or the
        conversion raises.
        """
        try:
            text = text.lstrip("<").rstrip(">")
            if not self._is_tone3_pinyin(text):
                logging.debug("Invalid pinyin token: %s", text)
                return []
            return self.separate_pinyin(text)
        except Exception as ex:
            logging.debug("Tokenize pinyin failed: %s", ex)
            return []

    @staticmethod
    def _is_tone3_pinyin(item: str) -> bool:
        """Return True if *item* is letters followed by one tone digit 1-5 (e.g. ``ni3``).

        Empty and single-character strings are rejected without indexing errors.
        """
        return len(item) > 1 and item[:-1].isalpha() and item[-1] in "12345"

    @staticmethod
    def separate_pinyin(text: str) -> List[str]:
        """Split a TONE3 syllable into ``[initial + "0", final]``.

        Either part is omitted when pypinyin returns it empty (e.g. for a
        syllable without an initial).
        """
        pinyins = []
        initial = to_initials(text, strict=False)
        final = to_finals_tone3(text, strict=False, neutral_tone_with_five=True)
        if initial:
            pinyins.append(initial + "0")
        if final:
            pinyins.append(final)
        return pinyins

    @staticmethod
    def map_punctuations(text: str) -> str:
        """Map CJK / typographic punctuation to ASCII (or "…") equivalents.

        Replacements are applied in order, so e.g. "。。。" first becomes "..."
        and is then collapsed to "…".
        """
        # Full-width forms are written as \u escapes so they cannot be mistaken
        # for (or normalized into) their ASCII look-alikes, which would turn
        # the entry into a no-op.
        replacements = {
            "\uff0c": ",",  # FULLWIDTH COMMA
            "。": ".",
            "\uff01": "!",  # FULLWIDTH EXCLAMATION MARK
            "\uff1f": "?",  # FULLWIDTH QUESTION MARK
            "\uff1b": ";",  # FULLWIDTH SEMICOLON
            "\uff1a": ":",  # FULLWIDTH COLON
            "、": ",",
            "‘": "'",
            "“": '"',
            "”": '"',
            "’": "'",
            "⋯": "…",
            "···": "…",
            "・・・": "…",
            "...": "…",
        }
        for src, dst in replacements.items():
            text = text.replace(src, dst)
        return text

    def get_segment(self, text: str):
        """Split *text* into ``(segment, type)`` pairs.

        Types are ``"zh"``, ``"en"``, ``"pinyin"``, ``"tag"`` or ``"other"``.
        Punctuation/whitespace ("other") is attached to the surrounding run;
        a leading "other" run adopts the type of the first classified part.
        """
        segments = []
        temp_seg = ""
        temp_lang = ""
        for index, part in enumerate(self._PART_RE.findall(text)):
            if self.is_chinese(part) or self.is_pinyin(part):
                part_type = "zh"
            elif self.is_alphabet(part):
                part_type = "en"
            else:
                part_type = "other"

            if index == 0 or temp_lang == "other":
                temp_seg += part
                temp_lang = part_type
            elif part_type in (temp_lang, "other"):
                temp_seg += part
            else:
                segments.append((temp_seg, temp_lang))
                temp_seg = part
                temp_lang = part_type
        segments.append((temp_seg, temp_lang))
        return self.split_segments(segments)

    @classmethod
    def split_segments(cls, segments):
        """Pull ``<...>`` groups out as "pinyin" and ``[...]`` groups as "tag" segments.

        Remaining non-empty pieces keep the type of the segment they came from.
        """
        result = []
        for temp_seg, temp_lang in segments:
            for part in cls._BRACKET_SPLIT_RE.split(temp_seg):
                if not part:
                    continue
                if part.startswith("<") and part.endswith(">"):
                    result.append((part, "pinyin"))
                elif part.startswith("[") and part.endswith("]"):
                    result.append((part, "tag"))
                else:
                    result.append((part, temp_lang))
        return result

    @staticmethod
    def is_chinese(char: str) -> bool:
        """Return True if *char* lies in U+4E00..U+9FA5 (CJK Unified Ideographs subset)."""
        return "\u4e00" <= char <= "\u9fa5"

    @staticmethod
    def is_alphabet(char: str) -> bool:
        """Return True if *char* is an ASCII letter."""
        return ("A" <= char <= "Z") or ("a" <= char <= "z")

    @staticmethod
    def is_pinyin(part: str) -> bool:
        """Return True if *part* is an inline ``<...>`` pinyin group."""
        return part.startswith("<") and part.endswith(">")