File size: 7,956 Bytes
e10f35b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import regex

def get_stats(ids, counts=None):
    """
    Count how often each pair of adjacent ids occurs in a sequence.

    Args:
        ids: Sequence of ids; must support slicing.
        counts: Optional dict to accumulate into. It is updated in place,
            which lets a caller gather statistics over many sequences.
            A fresh dict is created when this is None.

    Returns:
        Dict mapping (left, right) -> number of occurrences. A dict is
        always returned, even when `ids` has fewer than two elements
        (previously None leaked out in that case when no accumulator was
        supplied).
    """
    counts = {} if counts is None else counts
    # zip stops at the shorter argument, so 0- and 1-element inputs simply
    # contribute no pairs.
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge_once(ids, pair, idx):
    """
    Return a new list in which every occurrence of `pair` in `ids` is
    replaced by the single id `idx`.

    The scan runs left to right and matches do not overlap: once two
    elements have been merged, scanning resumes after them. `ids` itself
    is not modified.
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        # A pair can only start at a position that has a right neighbour.
        if pos < last and (ids[pos], ids[pos + 1]) == pair:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

def do_merge(ids, merges):
    """
    Collapse `ids` by repeatedly applying the merge rules in `merges`.

    `merges` maps an id pair to the id that replaces it. On every round
    the pair present in the sequence whose replacement id is smallest is
    merged everywhere; the loop ends when no present pair has a rule or
    fewer than two ids remain.

    Example: [1,2,3,1,2,3,4,1,2] with {(1,2):5,(5,3):6} --> [6,6,4,5]
    """
    current = ids
    while len(current) >= 2:
        # Pairs currently in the sequence that have a merge rule.
        applicable = [p for p in get_stats(current) if p in merges]
        if not applicable:
            break
        # Lowest replacement id first.
        chosen = min(applicable, key=merges.__getitem__)
        current = merge_once(current, chosen, merges[chosen])
    return current

class SpecialToken:
    """
    Marker object for a special (non-text) token, identified by its name.

    Two instances with the same name compare equal and hash alike, so
    instances can be used as dictionary keys.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        # Kept consistent with __eq__: equality is decided by name alone.
        return hash(self.name)

    def __eq__(self, other):
        # Anything that is not a SpecialToken is never equal.
        return isinstance(other, SpecialToken) and self.name == other.name

    def __repr__(self):
        return "SpecialToken({})".format(self.name)

    def __str__(self):
        return self.name

class Tokenizer:
    """
    Byte-level BPE tokenizer.

    Text is first split into chunks with a regular expression; each chunk
    is turned into its UTF-8 bytes (ids 0..255) and then compressed with
    the learned merge rules. Special tokens are kept in a separate table
    keyed by SpecialToken objects.
    """

    def __init__(self, pattern):
        """
        Create an untrained tokenizer.

        Args:
            pattern: Regular expression source used to split text into
                chunks; compiled with the `regex` module.
        """
        self.merges = {}                      # (id, id) -> merged id
        self.pattern_string = pattern         # kept verbatim for save()
        self.pattern = regex.compile(pattern)
        self.vocab = {idx: bytes([idx]) for idx in range(256)}  # id -> bytes
        self.vocab_size = 256
        self.special_tokens = {}              # SpecialToken -> id
        self.special_tokens_inv = {}          # id -> SpecialToken

    def train(self, vocab_size, dataloader, merge_increase_per_loop=1):
        """
        Learn merge rules until the vocabulary holds `vocab_size` tokens.

        For every batch yielded by `dataloader` the texts are split into
        chunks, converted to byte ids, compressed with the merges learned
        so far, and then up to `merge_increase_per_loop` new merges are
        learned from the most frequent adjacent pairs of that batch.

        Args:
            vocab_size: Target vocabulary size; must not be smaller than
                the current one.
            dataloader: Iterable yielding batches (iterables of str).
            merge_increase_per_loop: Maximum number of merges learned from
                a single batch.

        Returns:
            0 once the target size has been reached, None if the
            dataloader is exhausted first.

        Side effects:
            Prints each new merge and writes a checkpoint through save()
            whenever the vocabulary size becomes a multiple of 1000.
        """
        assert vocab_size >= self.vocab_size, \
            "vocab_size must not be smaller than the current vocabulary"

        # Already at the target: do not learn an extra merge.
        if self.vocab_size >= vocab_size:
            return 0

        for text_batch in dataloader:
            # Split every text into chunks and pool them in one list.
            text_chunks = []
            for text in text_batch:
                text_chunks += regex.findall(self.pattern, text)
            # Chunks as raw bytes, i.e. lists of integers in 0..255.
            ids = [list(ch.encode('utf-8')) for ch in text_chunks]
            # Apply the merges that are already known.
            ids = [do_merge(idlist, self.merges) for idlist in ids]
            for _ in range(merge_increase_per_loop):
                # Pair statistics over the whole batch.
                counts = None
                for idlist in ids:
                    counts = get_stats(idlist, counts)
                # No adjacent pair left in this batch (empty batch or only
                # single-id chunks): nothing to learn here, go on to the
                # next batch instead of crashing in max().
                if not counts:
                    break
                # Most frequent pair becomes the next token.
                pair = max(counts, key=counts.get)
                idx = self.vocab_size
                self.merges[pair] = idx
                self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
                self.vocab_size += 1
                print(f"New merge: {pair}->{idx}")
                if self.vocab_size >= vocab_size:
                    return 0
                # Re-compress the batch with the new rule.
                ids = [merge_once(idlist, pair, idx) for idlist in ids]
                if self.vocab_size % 1000 == 0:
                    self.save(f"tokenizer-{self.vocab_size}")

    def add_special_tokens(self, special_tokens):
        """
        Install the special-token table, replacing any previous one.

        Args:
            special_tokens: Dict mapping SpecialToken -> id.
        """
        self.special_tokens = special_tokens
        self.special_tokens_inv = {v: k for k, v in special_tokens.items()}

    def build_vocab(self):
        """
        Rebuild `vocab` and `vocab_size` from `merges`.

        Merges are processed in dict order, so both halves of a pair must
        already be in the vocabulary when the pair is reached.
        """
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        for (p1, p2), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p1] + self.vocab[p2]
            self.vocab_size += 1

    def encode(self, text):
        """Encode plain text (no special tokens) into a flat list of ids."""
        ids = []
        for ch in regex.findall(self.pattern, text):
            ids.extend(do_merge(list(ch.encode('utf-8')), self.merges))
        return ids

    def encode_all(self, text_special_mix):
        """
        Encode a sequence that mixes str pieces and SpecialToken objects.

        Raises:
            TypeError: If an element is neither str nor SpecialToken.
            KeyError: If a SpecialToken has not been registered.
        """
        ids = []
        for s in text_special_mix:
            if isinstance(s, str):
                ids += self.encode(s)
            elif isinstance(s, SpecialToken):
                ids.append(self.special_tokens[s])
            else:
                raise TypeError(
                    f"expected str or SpecialToken, got {type(s).__name__}")
        return ids

    def decode(self, ids):
        """
        Decode ids into a list of str pieces and SpecialToken objects.

        Consecutive ordinary ids are joined and decoded as UTF-8 (invalid
        sequences are replaced); every special id becomes its own list
        element. Unknown ids are reported with print() and skipped.
        """
        decoded = []
        pending = bytearray()
        for tid in ids:
            if tid in self.special_tokens_inv:
                if pending:
                    decoded.append(pending.decode("utf-8", errors="replace"))
                    pending.clear()
                token = self.special_tokens_inv[tid]
                # The table normally stores SpecialToken objects already;
                # wrapping one again would yield a token with a non-string
                # name that never equals the original.
                if not isinstance(token, SpecialToken):
                    token = SpecialToken(token)
                decoded.append(token)
            elif tid in self.vocab:
                pending += self.vocab[tid]
            else:
                print(f"{tid}: Error token id.")
        # Flush after the loop so that text is not lost when the final id
        # is unknown.
        if pending:
            decoded.append(pending.decode("utf-8", errors="replace"))
        return decoded

    def save(self, filename="tokenizer"):
        """
        Write `<filename>.model` and `<filename>.vocab`.

        The .model file holds a version line, the pattern, the number of
        special tokens, one "name id" line per special token and one
        "p1 p2 idx" line per merge. The .vocab file is a listing of
        ordinary and special tokens. Both are written as UTF-8, matching
        the encoding load_tokenizer reads with.
        """
        with open(f"{filename}.model", "w", encoding="utf-8") as f:
            f.write("Tokenizer V1\n")
            f.write(self.pattern_string + "\n")
            f.write(f"{len(self.special_tokens)}\n")
            for st, sid in self.special_tokens.items():
                f.write(f"{st.name} {sid}\n")
            for (p1, p2), idx in self.merges.items():
                f.write(f"{p1} {p2} {idx}\n")

        with open(f"{filename}.vocab", "w", encoding="utf-8") as f:
            f.write('Common Tokens:\n')
            for idx, bstr in self.vocab.items():
                # str(b'..')[2:-1] drops the b'' wrapper, keeping escapes.
                f.write(f"{idx} {str(bstr)[2:-1]}\n")
            f.write('Special Tokens:\n')
            for idx, spt in self.special_tokens_inv.items():
                f.write(f"{idx} {spt.name}\n")

def load_tokenizer(filename="tokenizer.model"):
    """
    Load a Tokenizer from a "Tokenizer V1" .model file.

    Expected layout: a version line, the split pattern, the number of
    special tokens, that many "name id" lines, then "p1 p2 idx" merge
    lines up to the first blank/short line or end of file.

    Args:
        filename: Path of the .model file (including the extension).

    Returns:
        A Tokenizer with merges, vocabulary and special tokens restored.
        Special tokens are also entered into `vocab` as b"<name>" and
        `vocab_size` is set to the largest id in use plus one.

    Raises:
        AssertionError: If the version line is not "Tokenizer V1".
        ValueError: If a special-token or merge line is malformed.
    """
    with open(filename, "r", encoding="utf-8") as f:
        version = f.readline().strip()
        assert version == "Tokenizer V1"
        # Remove only the line terminator: leading or trailing whitespace
        # can be a significant part of the regular expression.
        pat = f.readline().rstrip("\n")
        num_spt = int(f.readline().strip())
        spts = {}
        for _ in range(num_spt):
            line = f.readline().strip()
            # Split on the last whitespace run so that names containing
            # spaces survive; the id is always the final field.
            spt_name, spt_idx = line.rsplit(None, 1)
            spts[SpecialToken(spt_name)] = int(spt_idx)
        merges = {}
        line = f.readline().strip()
        # The shortest possible merge line ("0 0 0") has 5 characters;
        # anything shorter (blank line, EOF) ends the merge section.
        while len(line) >= 5:
            p1, p2, idx = line.split()
            merges[(int(p1), int(p2))] = int(idx)
            line = f.readline().strip()
    tokenizer = Tokenizer(pat)
    tokenizer.merges = merges
    tokenizer.build_vocab()
    tokenizer.add_special_tokens(spts)
    # Make sure vocab covers every special token and vocab_size accounts
    # for the highest id in use.
    max_id = max(tokenizer.vocab.keys()) if tokenizer.vocab else 255
    for st, sid in spts.items():
        tokenizer.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")
        if sid > max_id:
            max_id = sid
    tokenizer.vocab_size = max_id + 1
    return tokenizer