| from __future__ import annotations |
|
|
| import logging |
| import re |
| from functools import reduce |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import jieba |
| from pypinyin import Style, lazy_pinyin |
| from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials |
|
|
# piper_phonemize supplies `phonemize_espeak`, used by
# LocalEmiliaTokenizer.tokenize_en. Fail at import time with an actionable
# message if it cannot be imported, chaining the original error as the cause.
try:
    from piper_phonemize import phonemize_espeak
except Exception as ex:
    raise RuntimeError(
        f"{ex}\nPlease install piper_phonemize for English tokenization."
    ) from ex


# Raise jieba's logger threshold to INFO (presumably to silence its DEBUG
# output, e.g. while loading its dictionary -- TODO confirm intent).
jieba.default_logger.setLevel(logging.INFO)
|
|
|
|
class EnglishTextNormalizer:
    """English text normalizer; currently a pass-through placeholder."""

    def normalize(self, text: str) -> str:
        """Return ``text`` exactly as given (no normalization is applied)."""
        normalized = text
        return normalized
|
|
|
|
class ChineseTextNormalizer:
    """Best-effort Chinese normalizer backed by the optional ``cn2an`` package."""

    def normalize(self, text: str) -> str:
        """Run ``cn2an.transform(text, "an2cn")`` on ``text``.

        The "an2cn" mode presumably rewrites Arabic numerals as Chinese
        numerals (per cn2an's naming -- confirm against its docs). If
        ``cn2an`` is not installed, or the transform raises for any reason,
        the input text is returned unchanged.
        """
        try:
            # Imported lazily so cn2an stays an optional dependency.
            import cn2an

            converted = cn2an.transform(text, "an2cn")
        except Exception:
            converted = text
        return converted
|
|
|
|
class LocalEmiliaTokenizer:
    """Small board-side Emilia tokenizer without lhotse/CutSet dependencies.

    Turns mixed Chinese/English text into phone tokens and then token ids.
    Besides plain text, the input may contain inline ``<...>`` pinyin tags
    (e.g. ``<ni3>``), which are split into initial/final phones, and
    ``[...]`` tags, which are emitted verbatim as a single token. Tokens not
    present in the token table are dropped when converting to ids.
    """

    # Matches either a whole "<...>"/"[...]" tag (which, as before, may not
    # span a newline) or any single character. DOTALL lets the single-char
    # alternative match "\n", so line breaks are kept as separators instead
    # of being silently dropped (which glued words on adjacent lines).
    _SEGMENT_PATTERN = re.compile(r"[<[][^\n]*?[>\]]|.", re.DOTALL)

    def __init__(self, token_file: str | Path):
        """Load the token table.

        Args:
            token_file: UTF-8 text file with one ``token<TAB>id`` pair per
                line. Blank lines are ignored.

        Raises:
            ValueError: If a non-blank line does not split into exactly two
                tab-separated fields or its id is not an integer.
            KeyError: If the table has no ``"_"`` token (used as ``pad_id``).
        """
        self.english_normalizer = EnglishTextNormalizer()
        self.chinese_normalizer = ChineseTextNormalizer()
        self.token2id: Dict[str, int] = {}
        with open(token_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # Tolerate blank/trailing lines instead of crashing on unpack.
                    continue
                token, token_id = line.split("\t")
                self.token2id[token] = int(token_id)
        self.pad_id = self.token2id["_"]
        self.vocab_size = len(self.token2id)

    def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]:
        """Tokenize each text and map the tokens to ids (OOV tokens dropped)."""
        return self.tokens_to_token_ids(self.texts_to_tokens(texts))

    def texts_to_tokens(self, texts: List[str]) -> List[List[str]]:
        """Convert each text into a flat list of phone/tag tokens.

        Punctuation is first mapped to ASCII, then the text is split into
        language segments, and each segment is tokenized with the matching
        ``tokenize_*`` method. Segments of any other type are skipped.
        """
        phoneme_list = []
        for text in texts:
            text = self.map_punctuations(text)
            all_phoneme: List[str] = []
            for seg_text, seg_type in self.get_segment(text):
                if seg_type == "zh":
                    all_phoneme += self.tokenize_zh(seg_text)
                elif seg_type == "en":
                    all_phoneme += self.tokenize_en(seg_text)
                elif seg_type == "pinyin":
                    all_phoneme += self.tokenize_pinyin(seg_text)
                elif seg_type == "tag":
                    # "[...]" tags pass through as a single token.
                    all_phoneme.append(seg_text)
                else:
                    logging.debug("Skipping unknown language segment: %r", (seg_text, seg_type))
            phoneme_list.append(all_phoneme)
        return phoneme_list

    def tokens_to_token_ids(self, tokens_list: List[List[str]]) -> List[List[int]]:
        """Map token lists to id lists, skipping tokens missing from the table."""
        token_ids_list = []
        for tokens in tokens_list:
            token_ids = []
            for token in tokens:
                if token not in self.token2id:
                    logging.debug("Skip OOV token %s", token)
                    continue
                token_ids.append(self.token2id[token])
            token_ids_list.append(token_ids)
        return token_ids_list

    @staticmethod
    def _is_tone3_syllable(item: str) -> bool:
        """Return True if ``item`` is letters followed by a tone digit 1-5 (e.g. ``ni3``)."""
        return len(item) >= 2 and item[:-1].isalpha() and item[-1] in "12345"

    def tokenize_zh(self, text: str) -> List[str]:
        """Convert a Chinese segment into initial/final phones.

        The text is normalized, word-segmented with jieba and converted to
        TONE3 pinyin (tone sandhi on, neutral tone written as 5). Items that
        are not pinyin syllables (e.g. punctuation) are kept as-is. Any
        failure is logged at DEBUG level and yields an empty list.
        """
        try:
            text = self.chinese_normalizer.normalize(text)
            segs = list(jieba.cut(text))
            full = lazy_pinyin(
                segs,
                style=Style.TONE3,
                tone_sandhi=True,
                neutral_tone_with_five=True,
            )
            phones = []
            for item in full:
                if self._is_tone3_syllable(item):
                    phones.extend(self.separate_pinyin(item))
                else:
                    phones.append(item)
            return phones
        except Exception as ex:
            logging.debug("Tokenization of Chinese text failed: %s", ex)
            return []

    def tokenize_en(self, text: str) -> List[str]:
        """Phonemize an English segment with espeak (``en-us``).

        ``phonemize_espeak`` returns a list of phone lists, which are flattened
        into a single list. Any failure is logged at DEBUG level and yields an
        empty list.
        """
        try:
            text = self.english_normalizer.normalize(text)
            phone_groups = phonemize_espeak(text, "en-us")
            # Linear flatten (the previous reduce(+) was quadratic and raised
            # on empty input).
            return [phone for group in phone_groups for phone in group]
        except Exception as ex:
            logging.debug("Tokenization of English text failed: %s", ex)
            return []

    def tokenize_pinyin(self, text: str) -> List[str]:
        """Split an inline ``<pinyin>`` tag such as ``<ni3>`` into phones.

        Returns an empty list (logged at DEBUG level) if the tag body is not
        a TONE3 syllable or if splitting fails.
        """
        try:
            text = text.lstrip("<").rstrip(">")
            if not self._is_tone3_syllable(text):
                logging.debug("Invalid pinyin token: %s", text)
                return []
            return self.separate_pinyin(text)
        except Exception as ex:
            logging.debug("Tokenize pinyin failed: %s", ex)
            return []

    @staticmethod
    def separate_pinyin(text: str) -> List[str]:
        """Split a TONE3 pinyin syllable into ``[initial + "0", final]``.

        Either part is omitted when pypinyin reports it as empty. The "0"
        suffix on initials presumably marks them as toneless phones (TODO
        confirm against the token table).
        """
        pinyins = []
        initial = to_initials(text, strict=False)
        final = to_finals_tone3(text, strict=False, neutral_tone_with_five=True)
        if initial:
            pinyins.append(initial + "0")
        if final:
            pinyins.append(final)
        return pinyins

    @staticmethod
    def map_punctuations(text: str) -> str:
        """Map full-width/typographic punctuation to ASCII and unify ellipses.

        Replacements are applied sequentially in the listed order.
        """
        # Keys are full-width / CJK forms. The previous ASCII->ASCII entries
        # were no-ops, so full-width commas etc. were never normalized.
        replacements = {
            "，": ",",
            "。": ".",
            "！": "!",
            "？": "?",
            "；": ";",
            "：": ":",
            "、": ",",
            "‘": "'",
            "“": '"',
            "”": '"',
            "’": "'",
            "⋯": "…",
            "···": "…",
            "・・・": "…",
            "...": "…",
        }
        for src, dst in replacements.items():
            text = text.replace(src, dst)
        return text

    def _classify_part(self, part: str) -> str:
        """Return "zh" for Chinese chars or <pinyin> tags, "en" for ASCII letters, else "other"."""
        if self.is_chinese(part) or self.is_pinyin(part):
            return "zh"
        if self.is_alphabet(part):
            return "en"
        return "other"

    def get_segment(self, text: str):
        """Split ``text`` into ``(segment_text, type)`` pairs.

        Consecutive parts of the same language are merged. "other" parts
        (spaces, digits, punctuation, newlines, "[...]" tags) are absorbed
        into the current segment; a leading "other" run joins the next
        language. Tags are then split out by :meth:`split_segments`.
        """
        segments = []
        current_text = ""
        current_lang = ""
        for index, part in enumerate(self._SEGMENT_PATTERN.findall(text)):
            part_type = self._classify_part(part)
            if index == 0 or current_lang == "other":
                # Start a segment, or let an "other"-only prefix adopt this type.
                current_text += part
                current_lang = part_type
            elif part_type in (current_lang, "other"):
                current_text += part
            else:
                segments.append((current_text, current_lang))
                current_text = part
                current_lang = part_type
        segments.append((current_text, current_lang))
        return self.split_segments(segments)

    @staticmethod
    def split_segments(segments):
        """Split "<...>" tags out as "pinyin" and "[...]" tags out as "tag" segments.

        Remaining text keeps its segment's language; empty pieces are dropped.
        """
        result = []
        for temp_seg, temp_lang in segments:
            parts = re.split(r"([<[].*?[>\]])", temp_seg)
            for part in parts:
                if not part:
                    continue
                if part.startswith("<") and part.endswith(">"):
                    result.append((part, "pinyin"))
                elif part.startswith("[") and part.endswith("]"):
                    result.append((part, "tag"))
                else:
                    result.append((part, temp_lang))
        return result

    @staticmethod
    def is_chinese(char: str) -> bool:
        """Return True if ``char`` lies in U+4E00..U+9FA5 (CJK unified ideographs subset)."""
        return "\u4e00" <= char <= "\u9fa5"

    @staticmethod
    def is_alphabet(char: str) -> bool:
        """Return True if ``char`` is an ASCII letter."""
        return ("A" <= char <= "Z") or ("a" <= char <= "z")

    @staticmethod
    def is_pinyin(part: str) -> bool:
        """Return True if ``part`` is an inline ``<...>`` pinyin tag."""
        return part.startswith("<") and part.endswith(">")
|
|