import os
import re
from typing import List, Optional, Tuple

import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer


class SparkTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        clean_up_tokenization_spaces=False,
        split=True,
        **kwargs
    ):
        self.vocab_file = vocab_file
        self.split = split

        # Load the SentencePiece model.
        self.sp = spm.SentencePieceProcessor(model_file=vocab_file)

        # Build encoder/decoder maps from the sp model for compatibility.
        # These must exist before super().__init__(), which may call
        # get_vocab() during setup.
        self.encoder = {}
        self.decoder = {}
        for i in range(self.sp.get_piece_size()):
            piece = self.sp.id_to_piece(i)
            self.encoder[piece] = i
            self.decoder[i] = piece

        super().__init__(
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Standard special tokens. NOTE: the literal token strings were
        # stripped from the source; "<sep>", "<eod>", "<pad>", and "<unk>"
        # are assumed placeholders inferred from the attribute names --
        # adjust them to match the actual pieces in your SentencePiece model.
        self.sep_id = self.encoder.get('<sep>', None)
        self.eod_id = self.encoder.get('<eod>', None)
        self.pad_id = self.encoder.get('<pad>', 0)
        self.unk_id = self.encoder.get('<unk>', None)

    @property
    def vocab_size(self) -> int:
        return self.sp.get_piece_size()

    def get_vocab(self):
        return self.encoder

    def _tokenize(self, text: str) -> List[str]:
        # --- Megatron-compatible preprocessing ---
        # Normalize spacing after punctuation, drop newlines, and expand
        # tabs to four spaces.
        text = re.sub("(,|。|!|?) *", r"\1 ", text)
        text = text.replace("\n", "")
        text = text.replace("\t", " " * 4)

        if self.split:
            # Keep special tokens intact and run SentencePiece only on the
            # text between them. The token strings here are the same assumed
            # placeholders as in __init__.
            text_list = re.split(r'(<sep>|<eod>|<pad>)', text)
            pieces = []
            for each in text_list:
                if each in ['<sep>', '<eod>', '<pad>']:
                    pieces.append(each)
                else:
                    pieces.extend(self.sp.encode_as_pieces(each))
            return pieces
        return self.sp.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index):
        return self.decoder.get(index, "")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.sp.decode_pieces(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "tokenizer.model",
        )
        with open(vocab_file, "wb") as f:
            f.write(self.sp.serialized_model_proto())
        return (vocab_file,)
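

# Minimal usage sketch, not part of the tokenizer itself. It assumes a
# trained SentencePiece model named "tokenizer.model" exists in the current
# directory; the path and sample text are illustrative only. The public
# tokenize()/convert_tokens_to_ids() API routes through the _tokenize and
# _convert_token_to_id overrides defined above.
if __name__ == "__main__":
    tokenizer = SparkTokenizer(vocab_file="tokenizer.model")

    # Round-trip a sentence: text -> pieces -> ids -> text.
    text = "你好,世界!This is a test."
    tokens = tokenizer.tokenize(text)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)
    print(ids)
    print(tokenizer.convert_tokens_to_string(tokens))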