from __future__ import annotations import logging import re from functools import reduce from pathlib import Path from typing import Dict, List import jieba from pypinyin import Style, lazy_pinyin from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials try: from piper_phonemize import phonemize_espeak except Exception as ex: # pragma: no cover - board dependency check raise RuntimeError( f"{ex}\nPlease install piper_phonemize for English tokenization." ) jieba.default_logger.setLevel(logging.INFO) class EnglishTextNormalizer: def normalize(self, text: str) -> str: return text class ChineseTextNormalizer: def normalize(self, text: str) -> str: try: import cn2an return cn2an.transform(text, "an2cn") except Exception: return text class LocalEmiliaTokenizer: """Small board-side Emilia tokenizer without lhotse/CutSet dependencies.""" def __init__(self, token_file: str | Path): self.english_normalizer = EnglishTextNormalizer() self.chinese_normalizer = ChineseTextNormalizer() self.token2id: Dict[str, int] = {} with open(token_file, "r", encoding="utf-8") as f: for line in f: token, token_id = line.rstrip().split("\t") self.token2id[token] = int(token_id) self.pad_id = self.token2id["_"] self.vocab_size = len(self.token2id) def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]: return self.tokens_to_token_ids(self.texts_to_tokens(texts)) def texts_to_tokens(self, texts: List[str]) -> List[List[str]]: phoneme_list = [] for text in texts: text = self.map_punctuations(text) segments = self.get_segment(text) all_phoneme = [] for seg_text, seg_type in segments: if seg_type == "zh": all_phoneme += self.tokenize_zh(seg_text) elif seg_type == "en": all_phoneme += self.tokenize_en(seg_text) elif seg_type == "pinyin": all_phoneme += self.tokenize_pinyin(seg_text) elif seg_type == "tag": all_phoneme.append(seg_text) else: logging.debug("Skipping unknown language segment: %r", (seg_text, seg_type)) phoneme_list.append(all_phoneme) return phoneme_list def tokens_to_token_ids(self, tokens_list: List[List[str]]) -> List[List[int]]: token_ids_list = [] for tokens in tokens_list: token_ids = [] for token in tokens: if token not in self.token2id: logging.debug("Skip OOV token %s", token) continue token_ids.append(self.token2id[token]) token_ids_list.append(token_ids) return token_ids_list def tokenize_zh(self, text: str) -> List[str]: try: text = self.chinese_normalizer.normalize(text) segs = list(jieba.cut(text)) full = lazy_pinyin( segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True, ) phones = [] for item in full: if not (item[0:-1].isalpha() and item[-1] in ("1", "2", "3", "4", "5")): phones.append(item) else: phones.extend(self.separate_pinyin(item)) return phones except Exception as ex: logging.debug("Tokenization of Chinese text failed: %s", ex) return [] def tokenize_en(self, text: str) -> List[str]: try: text = self.english_normalizer.normalize(text) tokens = phonemize_espeak(text, "en-us") return reduce(lambda x, y: x + y, tokens) except Exception as ex: logging.debug("Tokenization of English text failed: %s", ex) return [] def tokenize_pinyin(self, text: str) -> List[str]: try: text = text.lstrip("<").rstrip(">") if not (text[0:-1].isalpha() and text[-1] in ("1", "2", "3", "4", "5")): logging.debug("Invalid pinyin token: %s", text) return [] return self.separate_pinyin(text) except Exception as ex: logging.debug("Tokenize pinyin failed: %s", ex) return [] @staticmethod def separate_pinyin(text: str) -> List[str]: pinyins = [] initial = to_initials(text, strict=False) final = to_finals_tone3(text, strict=False, neutral_tone_with_five=True) if initial: pinyins.append(initial + "0") if final: pinyins.append(final) return pinyins @staticmethod def map_punctuations(text: str) -> str: replacements = { ",": ",", "。": ".", "!": "!", "?": "?", ";": ";", ":": ":", "、": ",", "‘": "'", "“": '"', "”": '"', "’": "'", "⋯": "…", "···": "…", "・・・": "…", "...": "…", } for src, dst in replacements.items(): text = text.replace(src, dst) return text def get_segment(self, text: str): segments = [] types = [] temp_seg = "" temp_lang = "" parts = re.compile(r"[<[].*?[>\]]|.").findall(text) for part in parts: if self.is_chinese(part) or self.is_pinyin(part): types.append("zh") elif self.is_alphabet(part): types.append("en") else: types.append("other") for index, part_type in enumerate(types): if index == 0: temp_seg += parts[index] temp_lang = part_type elif temp_lang == "other": temp_seg += parts[index] temp_lang = part_type elif part_type in [temp_lang, "other"]: temp_seg += parts[index] else: segments.append((temp_seg, temp_lang)) temp_seg = parts[index] temp_lang = part_type segments.append((temp_seg, temp_lang)) return self.split_segments(segments) @staticmethod def split_segments(segments): result = [] for temp_seg, temp_lang in segments: parts = re.split(r"([<[].*?[>\]])", temp_seg) for part in parts: if not part: continue if part.startswith("<") and part.endswith(">"): result.append((part, "pinyin")) elif part.startswith("[") and part.endswith("]"): result.append((part, "tag")) else: result.append((part, temp_lang)) return result @staticmethod def is_chinese(char: str) -> bool: return "\u4e00" <= char <= "\u9fa5" @staticmethod def is_alphabet(char: str) -> bool: return ("A" <= char <= "Z") or ("a" <= char <= "z") @staticmethod def is_pinyin(part: str) -> bool: return part.startswith("<") and part.endswith(">")