| import regex |
|
|
def get_stats(ids, counts=None):
    """
    Count how often each pair of adjacent integers occurs in ``ids``.

    Args:
        ids: Sliceable sequence of integer token ids.
        counts: Optional dict mapping ``(left, right)`` pairs to their
            running counts. When given, it is updated in place so that
            statistics can be accumulated over several sequences.

    Returns:
        The dict of pair counts. A dict is always returned, even when
        ``ids`` has fewer than two elements and ``counts`` was not supplied
        (in which case the result is empty).
    """
    # Initialise before anything else so short inputs never yield None.
    counts = {} if counts is None else counts
    # zip() stops at the shorter input, so 0- or 1-element sequences
    # simply contribute no pairs.
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
|
|
def merge_once(ids, pair, idx):
    """
    Replace each occurrence of the adjacent id pair ``pair`` in ``ids``
    with the single id ``idx``.

    The sequence is scanned left to right and matches do not overlap:
    once a pair has been replaced, scanning resumes after it. The input
    is not modified; a new list is returned.
    """
    merged = []
    pos = 0
    last = len(ids) - 1
    while pos <= last:
        # A pair can only start at a position that has a right neighbour.
        if pos < last and (ids[pos], ids[pos + 1]) == pair:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged
|
|
def do_merge(ids, merges):
    """
    Apply the merge table ``merges`` to ``ids`` until no known pair remains.

    ``merges`` maps an id pair to the id that replaces it. On every round
    the pair present in the sequence with the smallest replacement id is
    merged first.
    Example: [1,2,3,1,2,3,4,1,2], {(1,2):5,(5,3):6} --> [6,6,4,5]
    """
    current = ids
    while len(current) >= 2:
        # Pairs that occur in the sequence and have a merge rule.
        known_pairs = [p for p in get_stats(current) if p in merges]
        if not known_pairs:
            break
        # Pick the pair whose replacement id is smallest.
        chosen = min(known_pairs, key=lambda p: merges[p])
        current = merge_once(current, chosen, merges[chosen])
    return current
|
|
class SpecialToken:
    """
    A named marker token that is distinct from ordinary text.

    Two instances compare equal, and hash alike, when their names match.
    An instance never compares equal to an object of another type.
    """

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"SpecialToken({self.name})"

    def __eq__(self, other):
        return isinstance(other, SpecialToken) and self.name == other.name

    def __hash__(self):
        # Hash on the name so equal tokens land in the same dict bucket.
        return hash(self.name)
|
|
class Tokenizer:
    """
    Byte-level BPE tokenizer.

    Text is split into chunks with a regex pattern, each chunk is encoded
    to UTF-8 bytes, and learned merges are applied to the byte ids.

    Attributes:
        merges: dict mapping an id pair ``(p1, p2)`` to the merged id.
        pattern_string: the split pattern as given to the constructor.
        pattern: the compiled split pattern.
        vocab: dict mapping a token id to the bytes it stands for.
        vocab_size: number of ordinary tokens (256 byte ids plus merges).
        special_tokens: dict mapping ``SpecialToken`` to its id.
        special_tokens_inv: reverse of ``special_tokens``.
    """

    def __init__(self, pattern):
        self.merges = {}
        self.pattern_string = pattern
        self.pattern = regex.compile(pattern)
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        self.special_tokens = {}
        self.special_tokens_inv = {}

    def train(self, vocab_size, dataloader, merge_increase_per_loop=1):
        """
        Train the tokenizer until it holds ``vocab_size`` tokens.

        For every batch of texts yielded by ``dataloader`` the texts are
        split into chunks, the existing merges are applied, and then up to
        ``merge_increase_per_loop`` new merges are learned from the most
        frequent adjacent pairs of that batch.

        Args:
            vocab_size: target number of tokens; must not be smaller than
                the current ``self.vocab_size``.
            dataloader: iterable yielding batches (iterables) of strings.
            merge_increase_per_loop: maximum number of merges learned per
                batch.

        Returns:
            0 once the target size has been reached, or None if the
            dataloader is exhausted first.
        """
        assert vocab_size >= self.vocab_size, "vocab_size is below the current size"
        # Already at the target: do not learn an extra merge.
        if self.vocab_size >= vocab_size:
            return 0

        for text_batch in dataloader:
            text_chunks = []
            for text in text_batch:
                text_chunks += self.pattern.findall(text)

            # Bytes of every chunk, with the merges learned so far applied.
            ids = [
                do_merge(list(ch.encode('utf-8')), self.merges)
                for ch in text_chunks
            ]
            for _ in range(merge_increase_per_loop):
                # Start from a real dict so the statistics are never None,
                # even when every chunk is too short to contain a pair.
                counts = {}
                for idlist in ids:
                    get_stats(idlist, counts)
                if not counts:
                    # Nothing left to merge in this batch; try the next one.
                    break

                pair = max(counts, key=counts.get)
                idx = self.vocab_size

                self.merges[pair] = idx
                self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
                self.vocab_size += 1
                print(f"New merge: {pair}->{idx}")
                if self.vocab_size >= vocab_size:
                    return 0

                ids = [merge_once(idlist, pair, idx) for idlist in ids]
                # Periodic checkpoint every 1000 tokens.
                if self.vocab_size % 1000 == 0:
                    self.save(f"tokenizer-{self.vocab_size}")

    def add_special_tokens(self, special_tokens):
        """
        Register special tokens.

        Args:
            special_tokens: dict mapping ``SpecialToken`` to its id. It
                replaces any previously registered special tokens.
        """
        self.special_tokens = special_tokens
        self.special_tokens_inv = {v: k for k, v in special_tokens.items()}

    def build_vocab(self):
        """
        Rebuild ``vocab`` and ``vocab_size`` from ``merges``.

        Merges are replayed in insertion order, so both halves of a pair
        must already be in the vocabulary when the pair is reached.
        """
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        for (p1, p2), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p1] + self.vocab[p2]
            self.vocab_size += 1

    def encode(self, text):
        """Encode plain text into a flat list of token ids."""
        ids = []
        for ch in self.pattern.findall(text):
            ids += do_merge(list(ch.encode('utf-8')), self.merges)
        return ids

    def encode_all(self, text_special_mix):
        """
        Encode a sequence that mixes strings and ``SpecialToken`` items.

        Raises:
            TypeError: if an item is neither ``str`` nor ``SpecialToken``.
            KeyError: if a ``SpecialToken`` has not been registered.
        """
        ids = []
        for s in text_special_mix:
            if isinstance(s, str):
                ids += self.encode(s)
            elif isinstance(s, SpecialToken):
                ids.append(self.special_tokens[s])
            else:
                raise TypeError(
                    f"expected str or SpecialToken, got {type(s).__name__}"
                )
        return ids

    def decode(self, ids):
        """
        Decode token ids into a list of strings and ``SpecialToken`` items.

        Consecutive ordinary tokens are joined and decoded as one UTF-8
        string (invalid bytes are replaced). Special tokens are emitted as
        separate list items. Unknown ids are reported with ``print`` and
        skipped without losing the text collected before them.
        """
        decoded = []
        pending = bytearray()
        for token_id in ids:
            if token_id in self.special_tokens_inv:
                if pending:
                    decoded.append(bytes(pending).decode("utf-8", errors="replace"))
                    pending.clear()
                token = self.special_tokens_inv[token_id]
                # The inverse table already stores SpecialToken objects;
                # wrapping one again would nest it inside another token's
                # name and break equality and str().
                if not isinstance(token, SpecialToken):
                    token = SpecialToken(token)
                decoded.append(token)
            elif token_id in self.vocab:
                pending += self.vocab[token_id]
            else:
                print(f"{token_id}: Error token id.")
        # Flush after the loop so trailing text survives regardless of
        # which kind of id came last.
        if pending:
            decoded.append(bytes(pending).decode("utf-8", errors="replace"))
        return decoded

    def save(self, filename="tokenizer"):
        """
        Write ``<filename>.model`` and ``<filename>.vocab``.

        The model file holds a version line, the split pattern, the number
        of special tokens, one ``name id`` line per special token, and one
        ``p1 p2 idx`` line per merge. The vocab file is a human-readable
        listing of ordinary and special tokens.
        """
        # UTF-8 is spelled out so the files match what load_tokenizer reads.
        with open(f"{filename}.model", "w", encoding="utf-8") as f:
            f.write("Tokenizer V1\n")
            f.write(self.pattern_string+"\n")
            f.write(f"{len(self.special_tokens)}\n")
            for st, sid in self.special_tokens.items():
                f.write(f"{st.name} {sid}\n")
            for (p1, p2), idx in self.merges.items():
                f.write(f"{p1} {p2} {idx}\n")

        with open(f"{filename}.vocab", "w", encoding="utf-8") as f:
            f.write('Common Tokens:\n')
            for idx, bstr in self.vocab.items():
                # str(b"...")[2:-1] drops the b'' wrapper but keeps escapes.
                f.write(f"{idx} {str(bstr)[2:-1]}\n")
            f.write('Special Tokens:\n')
            for idx, spt in self.special_tokens_inv.items():
                f.write(f"{idx} {spt.name}\n")
|
|
def load_tokenizer(filename="tokenizer.model"):
    """
    Load a tokenizer from a model file written by ``Tokenizer.save``.

    Expected layout: a ``Tokenizer V1`` version line, the split pattern,
    the number of special tokens, that many ``name id`` lines, then
    ``p1 p2 idx`` merge lines until a short or empty line (or end of file).

    Args:
        filename: path of the ``.model`` file.

    Returns:
        A ``Tokenizer`` with merges, vocabulary and special tokens
        restored. Special tokens are also placed in ``vocab`` as
        ``<name>`` bytes, and ``vocab_size`` is set to the largest id + 1.

    Raises:
        AssertionError: if the version line is not ``Tokenizer V1``.
    """
    with open(filename, "r", encoding="utf-8") as f:
        version = f.readline().strip()
        assert version == "Tokenizer V1"
        # Only drop the line terminator: leading or trailing whitespace can
        # be a meaningful part of the regex.
        pat = f.readline().rstrip("\n")
        num_spt = int(f.readline().strip())
        spts = {}
        for _ in range(num_spt):
            line = f.readline().strip()
            # Split on the last gap only, so names that contain spaces
            # (written as-is by save) still parse.
            spt_name, spt_idx = line.rsplit(maxsplit=1)
            spts[SpecialToken(spt_name)] = int(spt_idx)
        merges = {}
        line = f.readline().strip()
        # The shortest possible merge line ("a b c") is five characters.
        while len(line) >= 5:
            p1, p2, idx = line.split()
            merges[(int(p1), int(p2))] = int(idx)
            line = f.readline().strip()

    tokenizer = Tokenizer(pat)
    tokenizer.merges = merges
    tokenizer.build_vocab()
    tokenizer.add_special_tokens(spts)

    # Give special tokens a printable entry in the vocabulary.
    for st, sid in spts.items():
        tokenizer.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")
    # build_vocab always leaves the 256 byte entries, so vocab is non-empty.
    tokenizer.vocab_size = max(tokenizer.vocab) + 1
    return tokenizer
|
|
|
|
|
|