import os
import re
from typing import List, Optional, Tuple

import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer

class SparkTokenizer(PreTrainedTokenizer):
    """SentencePiece-backed tokenizer exposing the standard `PreTrainedTokenizer` interface."""

    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        clean_up_tokenization_spaces=False,
        split=True,
        **kwargs
    ):
        self.vocab_file = vocab_file
        self.split = split
        
        # Load SentencePiece model
        self.sp = spm.SentencePieceProcessor(model_file=vocab_file)
        
        # Build encoder/decoder from sp model for compatibility
        self.encoder = {}
        self.decoder = {}
        for i in range(self.sp.get_piece_size()):
            piece = self.sp.id_to_piece(i)
            self.encoder[piece] = i
            self.decoder[i] = piece

        super().__init__(
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )
        
        # Resolve ids of the special pieces used by this tokenizer
        # (pad falls back to 0; the others to None when absent from the vocab).
        self.sep_id = self.encoder.get('<s>', None)
        self.eod_id = self.encoder.get('<end>', None)
        self.pad_id = self.encoder.get('<pad>', 0)
        self.unk_id = self.encoder.get('<unk>', None)

    @property
    def vocab_size(self) -> int:
        return self.sp.get_piece_size()

    def get_vocab(self):
        # Return a copy so callers cannot mutate the internal mapping, and
        # include tokens added after initialisation, as the HF API expects.
        vocab = dict(self.encoder)
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        # --- Megatron-compatible preprocessing ---
        # Normalise spacing after CJK punctuation and map newlines/tabs to the
        # placeholders the model was trained with.
        text = re.sub("(,|。|!|?) *", r"\1 ", text)
        text = text.replace("\n", "<ret>")
        text = text.replace("\t", " " * 4)

        if self.split:
            # Split out the special tokens first so SentencePiece never sees
            # them, then tokenize the remaining plain-text segments.
            text_list = re.split(r'(<ret>|<end>|<s>)', text)
            pieces = []
            for each in text_list:
                if each in ['<ret>', '<end>', '<s>']:
                    pieces.append(each)
                else:
                    pieces.extend(self.sp.encode_as_pieces(each))
            return pieces
        return self.sp.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index):
        return self.decoder.get(index, "<unk>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.sp.decode_pieces(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory)
        
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "tokenizer.model")
        
        with open(vocab_file, "wb") as f:
            f.write(self.sp.serialized_model_proto())
        
        return (vocab_file,)
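

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the tokenizer class).
# It assumes a trained SentencePiece model is available locally as
# "tokenizer.model"; adjust the path and output directory for your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = SparkTokenizer(vocab_file="tokenizer.model")

    sample = "你好,世界!\n第二行"
    tokens = tokenizer.tokenize(sample)            # "\n" surfaces as the <ret> piece
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)
    print(ids)
    print(tokenizer.convert_tokens_to_string(tokens))

    # Persist the SentencePiece model so the tokenizer can be re-created later
    # from the tokenizer.model file written into this directory.
    print(tokenizer.save_vocabulary("./spark_tokenizer_out"))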