File size: 7,956 Bytes
e10f35b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 | import regex
def get_stats(ids, counts=None):
    """Count how often each pair of adjacent ids occurs in ``ids``.

    Args:
        ids: Sliceable sequence of token ids.
        counts: Optional dict mapping ``(id, id)`` pairs to occurrence counts.
            When given it is updated in place, which lets a caller accumulate
            statistics over several sequences.

    Returns:
        The dict of pair counts. When ``counts`` is omitted a fresh dict is
        returned even if ``ids`` has fewer than two elements, so the result
        is never ``None`` in that case.
    """
    counts = {} if counts is None else counts
    # zip stops at the shorter argument, so sequences with fewer than two
    # elements simply contribute no pairs.
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
def merge_once(ids, pair, idx):
    """Return a copy of ``ids`` with every occurrence of ``pair`` replaced by ``idx``.

    The scan runs left to right and consumes both elements of a match, so
    overlapping occurrences are merged greedily from the left.
    """
    merged = []
    pos = 0
    last = len(ids) - 1
    while pos <= last:
        # A pair can only start before the final element.
        if pos < last and (ids[pos], ids[pos + 1]) == pair:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged
def do_merge(ids, merges):
    """Compress ``ids`` by repeatedly applying the learned ``merges``.

    ``merges`` maps an ``(id, id)`` pair to the id that replaces it. On each
    round the applicable pair with the smallest replacement id is merged.
    Example: [1,2,3,1,2,3,4,1,2] with {(1,2):5,(5,3):6} --> [6,6,4,5]
    """
    current = ids
    while len(current) >= 2:
        # Pairs present in the current sequence that also have a merge rule.
        present = get_stats(current)
        candidates = [p for p in present if p in merges]
        if not candidates:
            # Nothing left that can be merged.
            break
        best = min(candidates, key=lambda p: merges[p])
        # Replace every occurrence of the chosen pair with its merged id.
        current = merge_once(current, best, merges[best])
    return current
class SpecialToken:
    """Marker object for a special token, identified by its ``name``.

    Two instances are equal when their names are equal, and the hash is the
    hash of the name, so instances can be used as dict keys.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        # Anything that is not a SpecialToken never compares equal.
        if not isinstance(other, SpecialToken):
            return False
        return self.name == other.name

    def __repr__(self):
        return f"SpecialToken({self.name})"

    def __str__(self):
        return self.name
class Tokenizer:
    """Byte-level BPE tokenizer with regex pre-splitting and special tokens.

    Attributes set by ``__init__``:
        merges: dict mapping an ``(id, id)`` pair to the id of the merged token.
        pattern_string: the split regex as text (written out by ``save``).
        pattern: the compiled split regex.
        vocab: dict mapping a token id to the bytes it stands for.
        vocab_size: id that the next learned merge will receive.
        special_tokens: dict mapping a SpecialToken to its id.
        special_tokens_inv: the inverse mapping, id to SpecialToken.
    """

    def __init__(self, pattern):
        """Create a tokenizer that only knows the 256 single-byte tokens."""
        self.merges = {}
        self.pattern_string = pattern
        self.pattern = regex.compile(pattern)
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        self.special_tokens = {}
        self.special_tokens_inv = {}

    def train(self, vocab_size, dataloader, merge_increase_per_loop=1):
        """Learn merges until ``self.vocab_size`` reaches ``vocab_size``.

        Every batch yielded by ``dataloader`` (an iterable of texts) is split
        into chunks with the pattern, turned into UTF-8 byte ids and
        compressed with the merges learned so far. Then up to
        ``merge_increase_per_loop`` of the most frequent adjacent pairs are
        added to ``self.merges``, one at a time.

        Returns:
            0 as soon as the target size is reached; ``None`` if the
            dataloader is exhausted first.
        """
        assert vocab_size >= self.vocab_size
        # Already at the target: learning another merge would overshoot it.
        if self.vocab_size >= vocab_size:
            return 0
        for text_batch in dataloader:
            # Split every text into chunks and pool the chunks of the batch.
            text_chunks = []
            for text in text_batch:
                text_chunks += self.pattern.findall(text)
            # Chunks as byte ids (0..255), compressed with the existing merges.
            ids = [
                do_merge(list(chunk.encode("utf-8")), self.merges)
                for chunk in text_chunks
            ]
            for _ in range(merge_increase_per_loop):
                # Pair statistics accumulated over all chunks.
                counts = None
                for idlist in ids:
                    counts = get_stats(idlist, counts)
                # No adjacent pair is left in this batch (for instance every
                # chunk is already a single token): move on to the next batch.
                if not counts:
                    break
                pair = max(counts, key=counts.get)
                # Register the new token under the next free id.
                idx = self.vocab_size
                self.merges[pair] = idx
                self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
                self.vocab_size += 1
                print(f"New merge: {pair}->{idx}")
                if self.vocab_size >= vocab_size:
                    return 0
                # Apply the new merge before counting pairs again.
                ids = [merge_once(idlist, pair, idx) for idlist in ids]
                # Checkpoint every time the size hits a multiple of 1000.
                if self.vocab_size % 1000 == 0:
                    self.save(f"tokenizer-{self.vocab_size}")

    def add_special_tokens(self, special_tokens):
        """Install ``special_tokens`` (SpecialToken -> id) and its inverse."""
        self.special_tokens = special_tokens
        self.special_tokens_inv = {v: k for k, v in special_tokens.items()}

    def build_vocab(self):
        """Rebuild ``vocab`` and ``vocab_size`` from ``self.merges``.

        Merges are replayed in dict order, so both halves of a pair must have
        been defined by an earlier entry (or be single bytes).
        """
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        for (p1, p2), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p1] + self.vocab[p2]
            self.vocab_size += 1

    def encode(self, text):
        """Encode plain ``text`` into a flat list of token ids."""
        ids = []
        for chunk in self.pattern.findall(text):
            ids += do_merge(list(chunk.encode("utf-8")), self.merges)
        return ids

    def encode_all(self, text_special_mix):
        """Encode a sequence mixing ``str`` pieces and SpecialToken objects.

        Raises:
            TypeError: if an element is neither a ``str`` nor a SpecialToken.
        """
        ids = []
        for s in text_special_mix:
            if isinstance(s, str):
                ids += self.encode(s)
            elif isinstance(s, SpecialToken):
                ids.append(self.special_tokens[s])
            else:
                raise TypeError
        return ids

    def decode(self, ids):
        """Decode ``ids`` into a list of text pieces and special tokens.

        Consecutive ordinary tokens are joined and decoded as UTF-8 (invalid
        sequences are replaced). Each special id yields the special token
        registered for it. Unknown ids are reported with ``print`` and
        skipped.
        """
        decoded = []
        pending = []  # bytes of ordinary tokens not yet turned into text

        def flush():
            if pending:
                decoded.append(b"".join(pending).decode("utf-8", errors="replace"))
                pending.clear()

        for token_id in ids:
            if token_id in self.special_tokens_inv:
                flush()
                # The inverse map already stores the token object; wrapping it
                # again would produce a token whose name is another token.
                decoded.append(self.special_tokens_inv[token_id])
            elif token_id in self.vocab:
                pending.append(self.vocab[token_id])
            else:
                print(f"{token_id}: Error token id.")
        # Flush after the loop so trailing text survives even when the last
        # id was unknown.
        flush()
        return decoded

    def save(self, filename="tokenizer"):
        """Write ``<filename>.model`` (loadable) and ``<filename>.vocab`` (readable).

        Both files are written as UTF-8, matching how the model file is read
        back by ``load_tokenizer``.
        """
        with open(f"{filename}.model", "w", encoding="utf-8") as f:
            f.write("Tokenizer V1\n")
            f.write(self.pattern_string + "\n")
            f.write(f"{len(self.special_tokens)}\n")
            for st, sid in self.special_tokens.items():
                f.write(f"{st.name} {sid}\n")
            for (p1, p2), idx in self.merges.items():
                f.write(f"{p1} {p2} {idx}\n")
        with open(f"{filename}.vocab", "w", encoding="utf-8") as f:
            f.write('Common Tokens:\n')
            for idx, bstr in self.vocab.items():
                # repr of the bytes without the surrounding b'...'
                f.write(f"{idx} {str(bstr)[2:-1]}\n")
            f.write('Special Tokens:\n')
            for idx, spt in self.special_tokens_inv.items():
                f.write(f"{idx} {spt.name}\n")
def load_tokenizer(filename="tokenizer.model"):
    """Rebuild a Tokenizer from a model file in the ``Tokenizer V1`` format.

    Expected layout: a version line, the split pattern, the number of
    special tokens, one ``<name> <id>`` line per special token, then one
    ``<p1> <p2> <idx>`` line per merge.

    Args:
        filename: Path of the model file, read as UTF-8.

    Returns:
        A Tokenizer with merges, vocab and special tokens restored. Special
        tokens are also entered into ``vocab`` as ``<name>`` bytes, and
        ``vocab_size`` is set to one past the largest id in use.
    """
    with open(filename, "r", encoding="utf-8") as f:
        version = f.readline().strip()
        assert version == "Tokenizer V1"
        # Remove only the line terminator: leading or trailing whitespace can
        # be a meaningful part of the regex.
        pat = f.readline().rstrip("\n")
        num_spt = int(f.readline().strip())
        spts = {}
        for _ in range(num_spt):
            # The id is the last field, so a name containing spaces survives.
            spt_name, spt_idx = f.readline().strip().rsplit(None, 1)
            spts[SpecialToken(spt_name)] = int(spt_idx)
        merges = {}
        line = f.readline().strip()
        # A merge line has three integers; a shorter line ends the section.
        while len(line) >= 5:
            p1, p2, idx = line.split()
            merges[(int(p1), int(p2))] = int(idx)
            line = f.readline().strip()
    tokenizer = Tokenizer(pat)
    tokenizer.merges = merges
    tokenizer.build_vocab()
    tokenizer.add_special_tokens(spts)
    # Make sure vocab covers every special token and vocab_size lies past
    # the largest id in use.
    max_id = max(tokenizer.vocab, default=255)
    for st, sid in spts.items():
        tokenizer.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")
        max_id = max(max_id, sid)
    tokenizer.vocab_size = max_id + 1
    return tokenizer
|