"""SentencePiece tokenizer wrapper for pinyin-code Transformers models.""" from __future__ import annotations import logging import re import shutil from pathlib import Path import sentencepiece as spm from transformers import PreTrainedTokenizer CHINESE_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff]") CHINESE_SPAN_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff]+") PUNCTUATION = set("\u3002\uff0c\u3001\uff1f\uff01\uff1a\uff1b.,?!:;()[]\u201c\u201d\"'") TOKEN_RE = re.compile( r"<[A-Z_]+>|" r"[\u3400-\u4dbf\u4e00-\u9fff]+|" r"[A-Za-z]+|" r"[\u3002\uff0c\u3001\uff1f\uff01\uff1a\uff1b.,?!:;()\[\]\u201c\u201d\"']|" r"\S" ) LABELS = { "\u9898\u5e72": "", "\u9009\u9879": "", "\u7b54\u6848": "", "\u89e3\u6790": "", } PINYIN_FORMAT_ALIASES = { "code": "pinyin-code", "codes": "pinyin-code", "pinyin-code": "pinyin-code", "initial": "pinyin-initial", "initials": "pinyin-initial", "pinyin-initial": "pinyin-initial", "hanzi": "hanzi", } class PinyinCodeTokenizer(PreTrainedTokenizer): """Slow tokenizer that preserves the existing SentencePiece model.""" vocab_files_names = {"vocab_file": "tokenizer.model"} model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab_file: str, add_bos_token: bool = False, add_eos_token: bool = False, transliteration: str = "pinyin-code", pinyin_format: str | None = None, use_jieba: bool = True, jieba: bool | None = None, **kwargs, ) -> None: self.vocab_file = vocab_file self.sp_model = spm.SentencePieceProcessor(model_file=vocab_file) self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token self.transliteration = self._normalize_transliteration( pinyin_format or transliteration ) self.use_jieba = use_jieba if jieba is None else jieba kwargs.setdefault("unk_token", self._piece_or_none(self.sp_model.unk_id())) kwargs.setdefault("bos_token", self._piece_or_none(self.sp_model.bos_id())) kwargs.setdefault("eos_token", self._piece_or_none(self.sp_model.eos_id())) kwargs.setdefault("pad_token", self._piece_or_none(self.sp_model.pad_id())) kwargs.setdefault("transliteration", self.transliteration) kwargs.setdefault("pinyin_format", self.transliteration) kwargs.setdefault("use_jieba", self.use_jieba) kwargs.setdefault("jieba", self.use_jieba) super().__init__(**kwargs) def _normalize_transliteration(self, value: str) -> str: normalized = PINYIN_FORMAT_ALIASES.get(value.lower()) if normalized is None: allowed = ", ".join(sorted(set(PINYIN_FORMAT_ALIASES.values()))) raise ValueError(f"Unsupported transliteration {value!r}; choose from {allowed}") return normalized def _piece_or_none(self, token_id: int) -> str | None: if token_id is None or token_id < 0: return None return self.sp_model.id_to_piece(token_id) def _preprocess_raw_text(self, text: str) -> str: if not CHINESE_RE.search(text): return text try: from preprocessing.preprocess import process_text, require_dependencies except ImportError: return self._fallback_process_text(text) require_dependencies() return process_text(text, self.transliteration, self.use_jieba) def _fallback_process_text(self, text: str) -> str: if self.use_jieba: try: import jieba except ImportError as exc: raise ImportError( "Tokenizing raw Mandarin benchmark text with jieba segmentation " "requires jieba. Install the model dependencies before running " "lm_eval." ) from exc jieba.setLogLevel(logging.WARNING) else: jieba = None if self.transliteration != "hanzi": try: from pypinyin import Style, pinyin except ImportError as exc: raise ImportError( "Tokenizing raw Mandarin benchmark text as pinyin requires pypinyin. " "Install the model dependencies before running lm_eval." ) from exc def normalize_text(value: str) -> str: value = re.sub(r"\$\$.*?\$\$", " ", value, flags=re.DOTALL) value = re.sub(r"[\uff08(]\s*[\uff09)]", " ", value) for label, marker in LABELS.items(): value = re.sub(rf"{label}\s*[:\uff1a]", f" {marker} ", value) value = re.sub( r"(? ", value, flags=re.I ) value = re.sub( r"(? ", value, flags=re.I ) value = re.sub( r"(? ", value, ) value = re.sub( r"(? ", value, ) value = value.replace("\uff08", "(").replace("\uff09", ")") return re.sub(r"\s+", " ", value).strip() def split_tone3_syllable(syllable: str) -> tuple[str, int]: match = re.fullmatch(r"([a-z\u00fcv]+)([1-5]?)", syllable.lower()) if not match: return syllable, 5 plain, tone = match.groups() return plain, int(tone or "5") def length_digit_offset(syllable: str) -> int: return min(max(len(syllable), 1), 5) - 1 def syllable_to_initial_code(syllable: str) -> str: plain, tone = split_tone3_syllable(syllable) if not plain: return "" tone_offset = 5 if tone in {3, 4, 5} else 0 digit = tone_offset + length_digit_offset(plain) initial = plain[0].upper() if tone in {1, 3, 5} else plain[0].lower() return f"{initial}{digit}" def syllable_to_initial_letter(syllable: str) -> str: plain, _ = split_tone3_syllable(syllable) return plain[:1].lower() def convert_word(word: str) -> str: if self.transliteration == "hanzi": return word syllables = pinyin(word, style=Style.TONE3, heteronym=False, errors="ignore") if self.transliteration == "pinyin-code": codes = [ syllable_to_initial_code(item[0]) for item in syllables if item and item[0] ] return "".join(code for code in codes if code) initials = [ syllable_to_initial_letter(item[0]) for item in syllables if item and item[0] ] return "".join(initial for initial in initials if initial) def tokenize_chinese_span(value: str) -> list[str]: tokens = [] words = jieba.cut(value, cut_all=False) if self.use_jieba else value for word in words: word = word.strip() if word and CHINESE_SPAN_RE.search(word): token = convert_word(word) if token: tokens.append(token) return tokens tokens = [] for part in TOKEN_RE.findall(normalize_text(text)): if part.startswith("<") and part.endswith(">"): tokens.append(part) elif CHINESE_SPAN_RE.fullmatch(part): tokens.extend(tokenize_chinese_span(part)) elif part in PUNCTUATION: tokens.append(part) elif re.fullmatch(r"[A-Za-z]+", part): upper = part.upper() tokens.append(upper if upper in {"A", "B", "C", "D"} else part.lower()) return " ".join(tokens) @property def vocab_size(self) -> int: return self.sp_model.get_piece_size() def get_vocab(self) -> dict[str, int]: vocab = {self.sp_model.id_to_piece(i): i for i in range(self.vocab_size)} vocab.update(self.added_tokens_encoder) return vocab def _tokenize(self, text: str) -> list[str]: text = self._preprocess_raw_text(text) return self.sp_model.encode(text, out_type=str) def _convert_token_to_id(self, token: str) -> int: return self.sp_model.piece_to_id(token) def _convert_id_to_token(self, index: int) -> str: return self.sp_model.id_to_piece(index) def convert_tokens_to_string(self, tokens: list[str]) -> str: return self.sp_model.decode(tokens) def build_inputs_with_special_tokens( self, token_ids_0: list[int], token_ids_1: list[int] | None = None, ) -> list[int]: output = list(token_ids_0) if self.add_bos_token and self.bos_token_id is not None: output = [self.bos_token_id] + output if self.add_eos_token and self.eos_token_id is not None: output = output + [self.eos_token_id] if token_ids_1 is not None: output += list(token_ids_1) if self.add_eos_token and self.eos_token_id is not None: output.append(self.eos_token_id) return output def get_special_tokens_mask( self, token_ids_0: list[int], token_ids_1: list[int] | None = None, already_has_special_tokens: bool = False, ) -> list[int]: if already_has_special_tokens: special_ids = set(self.all_special_ids) return [1 if token_id in special_ids else 0 for token_id in token_ids_0] mask = [0] * len(token_ids_0) if self.add_bos_token and self.bos_token_id is not None: mask = [1] + mask if self.add_eos_token and self.eos_token_id is not None: mask = mask + [1] if token_ids_1 is not None: mask += [0] * len(token_ids_1) if self.add_eos_token and self.eos_token_id is not None: mask.append(1) return mask def create_token_type_ids_from_sequences( self, token_ids_0: list[int], token_ids_1: list[int] | None = None, ) -> list[int]: return [0] * len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)) def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]: output_name = "tokenizer.model" if filename_prefix: output_name = f"{filename_prefix}-{output_name}" output_path = Path(save_directory) / output_name if Path(self.vocab_file).resolve() != output_path.resolve(): shutil.copyfile(self.vocab_file, output_path) return (str(output_path),)