"""Byte-level BPE tokenizer with regex pre-splitting and special tokens."""
import regex


def get_stats(ids, counts=None):
    """Count how often each pair of adjacent integers occurs in ``ids``.

    Args:
        ids: Sequence of integer token ids.
        counts: Optional existing ``{(a, b): count}`` dict to accumulate
            into. It is updated in place and returned. When ``None``, a new
            dict is created.

    Returns:
        The counts dict. It is always a dict, even when ``ids`` has fewer
        than two elements, so callers can safely test it for emptiness.
    """
    # Initialise before doing anything else so short inputs never yield None.
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge_once(ids, pair, idx):
    """Replace every occurrence of the adjacent ``pair`` in ``ids`` with ``idx``.

    The scan runs left to right and matches do not overlap. For example,
    with ``pair == (a, a)``, ``[a, a, a]`` becomes ``[idx, a]``.

    Returns:
        A new list; ``ids`` itself is not modified.
    """
    first, second = pair
    new_ids = []
    n = len(ids)
    i = 0
    while i < n:
        if i < n - 1 and ids[i] == first and ids[i + 1] == second:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids


def do_merge(ids, merges):
    """Apply the learned ``merges`` to ``ids`` until no merge applies.

    At each step the adjacent pair with the smallest merge id is merged.
    That is the pair learned earliest, which replays the training order.

    Example:
        ``[1,2,3,1,2,3,4,1,2]`` with ``{(1,2):5, (5,3):6}`` -> ``[6,6,4,5]``

    Returns:
        A new list if any merge happened; otherwise ``ids`` itself.
    """
    new_ids = ids
    while len(new_ids) >= 2:
        # Only pairs that have a learned merge are candidates.
        candidates = [p for p in get_stats(new_ids) if p in merges]
        if not candidates:
            break
        pair = min(candidates, key=merges.__getitem__)
        new_ids = merge_once(new_ids, pair, merges[pair])
    return new_ids


class SpecialToken:
    """A named special token, compared and hashed by its ``name``."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"SpecialToken({self.name})"

    def __eq__(self, other):
        if isinstance(other, SpecialToken):
            return self.name == other.name
        return False

    def __hash__(self):
        return hash(self.name)


class Tokenizer:
    """Byte-level BPE tokenizer.

    Text is split into chunks with ``pattern`` (a ``regex``-module pattern).
    Each chunk is UTF-8 encoded into ids 0..255, and learned merges are then
    applied within each chunk.

    Attributes:
        merges: ``{(id1, id2): new_id}`` learned merge rules.
        pattern_string: Source text of the split pattern (written by save()).
        pattern: The compiled ``pattern_string``.
        vocab: ``{id: bytes}`` mapping used by decode().
        vocab_size: The id the next learned merge will receive.
        special_tokens: ``{SpecialToken: id}``.
        special_tokens_inv: ``{id: SpecialToken}``.
    """

    def __init__(self, pattern):
        self.merges = {}
        self.pattern_string = pattern
        self.pattern = regex.compile(pattern)
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        self.special_tokens = {}
        self.special_tokens_inv = {}

    def _encode_chunks(self, text):
        """Split ``text`` with the pattern and return each chunk's merged ids."""
        return [do_merge(list(chunk.encode("utf-8")), self.merges)
                for chunk in self.pattern.findall(text)]

    def train(self, vocab_size, dataloader, merge_increase_per_loop=1):
        """Learn merges until ``self.vocab_size`` reaches ``vocab_size``.

        For each batch yielded by ``dataloader`` (an iterable of string
        iterables), the texts are split into chunks, converted to bytes and
        compressed with the merges learned so far. Then up to
        ``merge_increase_per_loop`` new merges are learned, each from the
        most frequent adjacent pair in that batch.

        Checkpointing: after a merge that brings the vocabulary size to a
        multiple of 1000, a checkpoint ``tokenizer-<size>`` is saved. This
        is skipped when that merge reaches the target size.

        Returns:
            0 once the target size is reached, or ``None`` if ``dataloader``
            is exhausted first.

        Raises:
            ValueError: if ``vocab_size`` is smaller than the current size.
        """
        if vocab_size < self.vocab_size:
            raise ValueError(
                f"vocab_size ({vocab_size}) must be >= current vocab_size "
                f"({self.vocab_size})"
            )
        # Already at the target: do not learn an extra merge.
        if self.vocab_size >= vocab_size:
            return 0
        # NOTE(review): new merge ids start at self.vocab_size. If special
        # tokens were added with ids >= vocab_size before training, ids may
        # collide. load_tokenizer avoids this by bumping vocab_size.
        for text_batch in dataloader:
            ids = [chunk_ids
                   for text in text_batch
                   for chunk_ids in self._encode_chunks(text)]
            for _ in range(merge_increase_per_loop):
                counts = {}
                for idlist in ids:
                    get_stats(idlist, counts)
                if not counts:
                    # Every chunk in this batch is a single token: nothing
                    # left to learn here, move on to the next batch.
                    break
                pair = max(counts, key=counts.get)
                idx = self.vocab_size
                self.merges[pair] = idx
                self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
                self.vocab_size += 1
                print(f"New merge: {pair}->{idx}")
                if self.vocab_size >= vocab_size:
                    return 0
                ids = [merge_once(idlist, pair, idx) for idlist in ids]
                if self.vocab_size % 1000 == 0:
                    self.save(f"tokenizer-{self.vocab_size}")
        return None

    def add_special_tokens(self, special_tokens):
        """Replace the special tokens with ``special_tokens`` (``{SpecialToken: id}``)."""
        self.special_tokens = special_tokens
        self.special_tokens_inv = {v: k for k, v in special_tokens.items()}

    def build_vocab(self):
        """Rebuild ``vocab`` and ``vocab_size`` from the 256 byte ids plus ``merges``."""
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        # Build in id order so both parts of a merge already exist, even if
        # ``merges`` was filled out of order.
        for (p1, p2), idx in sorted(self.merges.items(), key=lambda item: item[1]):
            self.vocab[idx] = self.vocab[p1] + self.vocab[p2]
        self.vocab_size = 256 + len(self.merges)

    def encode(self, text):
        """Encode a plain string (no special tokens) into a flat list of ids."""
        return [tid for chunk_ids in self._encode_chunks(text) for tid in chunk_ids]

    def encode_all(self, text_special_mix):
        """Encode a sequence mixing ``str`` pieces and ``SpecialToken`` objects.

        Raises:
            KeyError: if a SpecialToken has not been registered.
            TypeError: for any element that is neither ``str`` nor ``SpecialToken``.
        """
        ids = []
        for s in text_special_mix:
            if isinstance(s, str):
                ids += self.encode(s)
            elif isinstance(s, SpecialToken):
                ids.append(self.special_tokens[s])
            else:
                raise TypeError(
                    f"expected str or SpecialToken, got {type(s).__name__}"
                )
        return ids

    def decode(self, ids):
        """Decode ids into a list of ``str`` pieces and ``SpecialToken`` objects.

        Runs of ordinary ids are joined at the byte level and decoded as one
        UTF-8 string, so multi-byte characters split across tokens decode
        correctly. Invalid sequences become U+FFFD. Each special-token id
        yields its SpecialToken. Unknown ids are reported on stdout and
        skipped.
        """
        decoded = []
        pending = bytearray()

        def flush():
            # Emit the buffered text bytes (if any) as one decoded string.
            if pending:
                decoded.append(bytes(pending).decode("utf-8", errors="replace"))
                pending.clear()

        for tid in ids:
            # Special ids are checked first: load_tokenizer also puts them in vocab.
            if tid in self.special_tokens_inv:
                flush()
                tok = self.special_tokens_inv[tid]
                # Values are normally SpecialToken already; do not double-wrap.
                decoded.append(tok if isinstance(tok, SpecialToken) else SpecialToken(tok))
            elif tid in self.vocab:
                pending += self.vocab[tid]
            else:
                print(f"{tid}: Error token id.")
        # Flush trailing text even when the last id was unknown.
        flush()
        return decoded

    def save(self, filename="tokenizer"):
        """Write ``<filename>.model`` and a human-readable ``<filename>.vocab``.

        The ``.model`` file is the format read by load_tokenizer(). Both
        files are UTF-8 to match that loader.

        NOTE(review): the format is line-based, so a pattern containing a
        newline cannot round-trip.
        """
        with open(f"{filename}.model", "w", encoding="utf-8") as f:
            f.write("Tokenizer V1\n")
            f.write(self.pattern_string + "\n")
            f.write(f"{len(self.special_tokens)}\n")
            for st, sid in self.special_tokens.items():
                f.write(f"{st.name} {sid}\n")
            for (p1, p2), idx in self.merges.items():
                f.write(f"{p1} {p2} {idx}\n")
        with open(f"{filename}.vocab", "w", encoding="utf-8") as f:
            f.write('Common Tokens:\n')
            for idx, bstr in self.vocab.items():
                # repr of bytes without the leading b' and trailing '
                f.write(f"{idx} {str(bstr)[2:-1]}\n")
            f.write('Special Tokens:\n')
            for idx, spt in self.special_tokens_inv.items():
                f.write(f"{idx} {spt.name}\n")


def load_tokenizer(filename="tokenizer.model"):
    """Load a tokenizer from a ``.model`` file written by ``Tokenizer.save``.

    Each special token also gets a printable ``<name>`` vocab entry.
    ``vocab_size`` is set past the largest known id, so further training
    cannot reuse a special-token id.

    Raises:
        ValueError: if the file header is not ``Tokenizer V1`` or a special
            token line is malformed.
    """
    with open(filename, "r", encoding="utf-8") as f:
        version = f.readline().strip()
        if version != "Tokenizer V1":
            raise ValueError(f"unsupported tokenizer file version: {version!r}")
        # Only drop the line terminator: the pattern may legitimately begin or
        # end with whitespace.
        pat = f.readline().rstrip("\n")
        num_spt = int(f.readline().strip())
        spts = {}
        for _ in range(num_spt):
            line = f.readline().rstrip("\n")
            # Split on the last space so names containing spaces still load.
            spt_name, spt_idx = line.rsplit(" ", 1)
            spts[SpecialToken(spt_name)] = int(spt_idx)
        merges = {}
        for line in f:
            parts = line.split()
            if len(parts) != 3:  # blank line / end of merge section
                break
            p1, p2, idx = map(int, parts)
            merges[(p1, p2)] = idx

    tokenizer = Tokenizer(pat)
    tokenizer.merges = merges
    tokenizer.build_vocab()
    tokenizer.add_special_tokens(spts)
    # Make sure vocab covers every special token and vocab_size is past every id.
    max_id = max(tokenizer.vocab)  # build_vocab always provides ids 0..255
    for st, sid in spts.items():
        tokenizer.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")
        max_id = max(max_id, sid)
    tokenizer.vocab_size = max_id + 1
    return tokenizer