File size: 7,451 Bytes
ea47387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
from __future__ import annotations

import logging
import re
from functools import reduce
from pathlib import Path
from typing import Dict, List

import jieba
from pypinyin import Style, lazy_pinyin
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials

# phonemize_espeak is required by LocalEmiliaTokenizer.tokenize_en. Any
# failure while importing it is re-raised as a RuntimeError whose message is
# the original error text followed by an installation hint, so importing this
# module fails immediately when the dependency is unusable.
try:
    from piper_phonemize import phonemize_espeak
except Exception as ex:  # pragma: no cover - board dependency check
    raise RuntimeError(
        f"{ex}\nPlease install piper_phonemize for English tokenization."
    )

# Set the level of jieba's default logger to INFO.
jieba.default_logger.setLevel(logging.INFO)


class EnglishTextNormalizer:
    """Pass-through normalizer for English text."""

    def normalize(self, text: str) -> str:
        """Return ``text`` exactly as received; no rewriting is applied."""
        unchanged = text
        return unchanged


class ChineseTextNormalizer:
    """Best-effort normalizer for Chinese text backed by the optional ``cn2an``.

    Normalization never raises: when ``cn2an`` cannot be imported, or its
    transform fails, the input text is returned unchanged and the reason is
    logged at debug level.
    """

    def normalize(self, text: str) -> str:
        """Return ``cn2an.transform(text, "an2cn")``, or ``text`` on failure.

        Args:
            text: The text to normalize.

        Returns:
            The transformed text, or ``text`` itself when ``cn2an`` is
            unavailable or the transform raises.
        """
        # Imported lazily so that the module works without cn2an installed.
        try:
            import cn2an
        except Exception as ex:
            logging.debug(
                "cn2an is unavailable, skipping Chinese normalization: %s", ex
            )
            return text

        try:
            return cn2an.transform(text, "an2cn")
        except Exception as ex:
            # Keep the pipeline running on inputs cn2an cannot handle.
            logging.debug("cn2an failed to normalize %r: %s", text, ex)
            return text


class LocalEmiliaTokenizer:
    """Small board-side Emilia tokenizer without lhotse/CutSet dependencies."""

    def __init__(self, token_file: str | Path):
        """Load the token vocabulary from ``token_file``.

        Args:
            token_file: Path to a UTF-8 text file holding one
                ``token<TAB>id`` entry per line. Trailing whitespace on a
                line is ignored, and blank lines are skipped.

        Raises:
            ValueError: If a non-blank line is not exactly ``token<TAB>id``
                with an integer id. The message names the file and line.
            KeyError: If the vocabulary has no ``"_"`` entry; its id is used
                as ``pad_id``.
        """
        self.english_normalizer = EnglishTextNormalizer()
        self.chinese_normalizer = ChineseTextNormalizer()
        self.token2id: Dict[str, int] = {}
        with open(token_file, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                entry = line.rstrip()
                if not entry:
                    # A stray blank line should not abort loading.
                    continue
                try:
                    token, token_id = entry.split("\t")
                    self.token2id[token] = int(token_id)
                except ValueError as ex:
                    raise ValueError(
                        f"{token_file}:{line_no}: expected 'token<TAB>id', "
                        f"got {entry!r}"
                    ) from ex
        self.pad_id = self.token2id["_"]
        self.vocab_size = len(self.token2id)

    def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]:
        return self.tokens_to_token_ids(self.texts_to_tokens(texts))

    def texts_to_tokens(self, texts: List[str]) -> List[List[str]]:
        phoneme_list = []
        for text in texts:
            text = self.map_punctuations(text)
            segments = self.get_segment(text)
            all_phoneme = []
            for seg_text, seg_type in segments:
                if seg_type == "zh":
                    all_phoneme += self.tokenize_zh(seg_text)
                elif seg_type == "en":
                    all_phoneme += self.tokenize_en(seg_text)
                elif seg_type == "pinyin":
                    all_phoneme += self.tokenize_pinyin(seg_text)
                elif seg_type == "tag":
                    all_phoneme.append(seg_text)
                else:
                    logging.debug("Skipping unknown language segment: %r", (seg_text, seg_type))
            phoneme_list.append(all_phoneme)
        return phoneme_list

    def tokens_to_token_ids(self, tokens_list: List[List[str]]) -> List[List[int]]:
        token_ids_list = []
        for tokens in tokens_list:
            token_ids = []
            for token in tokens:
                if token not in self.token2id:
                    logging.debug("Skip OOV token %s", token)
                    continue
                token_ids.append(self.token2id[token])
            token_ids_list.append(token_ids)
        return token_ids_list

    def tokenize_zh(self, text: str) -> List[str]:
        """Tokenize a Chinese segment into pinyin-based tokens.

        The text is normalized, cut into words with ``jieba.cut`` and converted
        with ``lazy_pinyin``. An item made of letters followed by a tone digit
        ``1``-``5`` is split by ``separate_pinyin``; any other item is kept as
        a single token.

        Returns:
            The token list, or ``[]`` if any step raises (the error is logged
            at debug level).
        """
        tone_digits = ("1", "2", "3", "4", "5")
        try:
            normalized = self.chinese_normalizer.normalize(text)
            words = list(jieba.cut(normalized))
            syllables = lazy_pinyin(
                words,
                style=Style.TONE3,
                tone_sandhi=True,
                neutral_tone_with_five=True,
            )
            result = []
            for syllable in syllables:
                # Short-circuit keeps an empty item from being indexed.
                is_toned = syllable[0:-1].isalpha() and syllable[-1] in tone_digits
                if is_toned:
                    result.extend(self.separate_pinyin(syllable))
                else:
                    result.append(syllable)
            return result
        except Exception as ex:
            logging.debug("Tokenization of Chinese text failed: %s", ex)
            return []

    def tokenize_en(self, text: str) -> List[str]:
        """Tokenize an English segment into a flat list of phoneme tokens.

        The text is normalized and passed to ``phonemize_espeak`` with the
        ``"en-us"`` voice. The result is treated as a sequence of token lists
        (presumably one per sentence - confirm against piper_phonemize) and
        flattened in order.

        Returns:
            The flattened token list. An empty phonemizer result yields ``[]``.
            If any step raises, ``[]`` is returned and the error is logged at
            debug level.
        """
        try:
            text = self.english_normalizer.normalize(text)
            sentences = phonemize_espeak(text, "en-us")
            # A comprehension flattens in linear time and, unlike reduce()
            # without an initializer, does not raise on an empty result.
            return [phoneme for sentence in sentences for phoneme in sentence]
        except Exception as ex:
            logging.debug("Tokenization of English text failed: %s", ex)
            return []

    def tokenize_pinyin(self, text: str) -> List[str]:
        try:
            text = text.lstrip("<").rstrip(">")
            if not (text[0:-1].isalpha() and text[-1] in ("1", "2", "3", "4", "5")):
                logging.debug("Invalid pinyin token: %s", text)
                return []
            return self.separate_pinyin(text)
        except Exception as ex:
            logging.debug("Tokenize pinyin failed: %s", ex)
            return []

    @staticmethod
    def separate_pinyin(text: str) -> List[str]:
        """Split one pinyin syllable into up to two tokens.

        The initial comes from ``to_initials`` and the final from
        ``to_finals_tone3``, both called with ``strict=False``.

        Returns:
            ``[initial + "0", final]``, with either element left out when the
            corresponding lookup returns an empty value.
        """
        onset = to_initials(text, strict=False)
        rime = to_finals_tone3(text, strict=False, neutral_tone_with_five=True)
        # The "0" suffix is appended to every non-empty initial.
        pieces = [onset + "0"] if onset else []
        if rime:
            pieces.append(rime)
        return pieces

    @staticmethod
    def map_punctuations(text: str) -> str:
        replacements = {
            ",": ",",
            "。": ".",
            "!": "!",
            "?": "?",
            ";": ";",
            ":": ":",
            "、": ",",
            "‘": "'",
            "“": '"',
            "”": '"',
            "’": "'",
            "⋯": "…",
            "···": "…",
            "・・・": "…",
            "...": "…",
        }
        for src, dst in replacements.items():
            text = text.replace(src, dst)
        return text

    def get_segment(self, text: str):
        segments = []
        types = []
        temp_seg = ""
        temp_lang = ""
        parts = re.compile(r"[<[].*?[>\]]|.").findall(text)
        for part in parts:
            if self.is_chinese(part) or self.is_pinyin(part):
                types.append("zh")
            elif self.is_alphabet(part):
                types.append("en")
            else:
                types.append("other")

        for index, part_type in enumerate(types):
            if index == 0:
                temp_seg += parts[index]
                temp_lang = part_type
            elif temp_lang == "other":
                temp_seg += parts[index]
                temp_lang = part_type
            elif part_type in [temp_lang, "other"]:
                temp_seg += parts[index]
            else:
                segments.append((temp_seg, temp_lang))
                temp_seg = parts[index]
                temp_lang = part_type
        segments.append((temp_seg, temp_lang))
        return self.split_segments(segments)

    @staticmethod
    def split_segments(segments):
        result = []
        for temp_seg, temp_lang in segments:
            parts = re.split(r"([<[].*?[>\]])", temp_seg)
            for part in parts:
                if not part:
                    continue
                if part.startswith("<") and part.endswith(">"):
                    result.append((part, "pinyin"))
                elif part.startswith("[") and part.endswith("]"):
                    result.append((part, "tag"))
                else:
                    result.append((part, temp_lang))
        return result

    @staticmethod
    def is_chinese(char: str) -> bool:
        return "\u4e00" <= char <= "\u9fa5"

    @staticmethod
    def is_alphabet(char: str) -> bool:
        return ("A" <= char <= "Z") or ("a" <= char <= "z")

    @staticmethod
    def is_pinyin(part: str) -> bool:
        return part.startswith("<") and part.endswith(">")