# LightNovelModel-Alpha / tokenizer.py
# Source: Hugging Face repository page; last commit e10f35b by hugfaceguy0001
# ("upload model and train/infer codes"), file size 7.96 kB.
import regex
def get_stats(ids, counts=None):
    """
    Count how often each pair of adjacent ids occurs in ``ids``.

    Args:
        ids: sliceable sequence of integer token ids.
        counts: optional dict to accumulate into. It is updated in place and
            returned, so statistics from several sequences can be combined.

    Returns:
        dict mapping (id_a, id_b) -> occurrence count. This is always a dict.
        If ``ids`` has fewer than two elements and no ``counts`` was given, a
        new empty dict is returned. Previously ``None`` was returned, which
        broke callers that feed the result to ``max``.
    """
    counts = {} if counts is None else counts
    # zip of a sequence with itself shifted by one yields every adjacent pair;
    # it yields nothing for sequences of length 0 or 1.
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
def merge_once(ids, pair, idx):
    """
    Return a new list in which each occurrence of ``pair`` in ``ids`` is replaced by ``idx``.

    The scan runs left to right. A match consumes both elements, so overlapping
    occurrences are merged greedily from the left (e.g. (1, 1) in [1, 1, 1]
    gives [idx, 1]). ``ids`` itself is not modified.
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        if pos < last and (ids[pos], ids[pos + 1]) == pair:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged
def do_merge(ids, merges):
    """
    Repeatedly apply the merges in ``merges`` to ``ids`` until no adjacent pair is mergeable.

    Each step picks, among the adjacent pairs present, the one with the smallest
    merge id and replaces all its occurrences. For merges produced by
    ``Tokenizer.train`` this is the earliest-learned merge.

    Example: [1,2,3,1,2,3,4,1,2] with {(1,2):5, (5,3):6} --> [6,6,4,5]
    """
    result = ids
    while len(result) > 1:
        # Adjacent pairs that have a learned merge, in first-occurrence order.
        candidates = [p for p in get_stats(result) if p in merges]
        if not candidates:
            break
        best = min(candidates, key=merges.__getitem__)
        result = merge_once(result, best, merges[best])
    return result
class SpecialToken:
    """
    A named token that is mapped directly to an id instead of being BPE-encoded.

    Equality and hashing depend only on ``name``. A freshly built
    ``SpecialToken("x")`` therefore finds a dict entry registered earlier under
    another ``SpecialToken("x")``. A SpecialToken never compares equal to a
    plain string.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"SpecialToken({self.name})"

    def __str__(self):
        return self.name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return isinstance(other, SpecialToken) and self.name == other.name
class Tokenizer:
    """
    Byte-level BPE tokenizer with regex pre-splitting and optional special tokens.

    Text is split into chunks with ``pattern``, compiled with the third-party
    ``regex`` module. Each chunk is UTF-8 encoded into byte ids 0-255, and
    learned merges then combine adjacent ids into larger ids. Merges are applied
    per chunk, so they never cross chunk boundaries.

    Attributes:
        merges: dict mapping an (id, id) pair to the id it is merged into.
        pattern_string: the raw regex source (written out by ``save``).
        pattern: the compiled regex.
        vocab: dict mapping id -> the bytes that id decodes to.
        vocab_size: number of ids in use; ``train`` uses this value as the id
            of the next merge it learns.
        special_tokens: dict mapping SpecialToken -> id.
        special_tokens_inv: dict mapping id -> SpecialToken.
    """

    def __init__(self, pattern):
        self.merges = {}
        self.pattern_string = pattern
        self.pattern = regex.compile(pattern)
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        self.special_tokens = {}
        self.special_tokens_inv = {}

    def train(self, vocab_size, dataloader, merge_increase_per_loop=1):
        """
        Learn merges until the vocabulary contains ``vocab_size`` tokens.

        For every batch of texts yielded by ``dataloader``:
        1. Split the texts into chunks.
        2. Encode the chunks with the merges learned so far.
        3. Learn up to ``merge_increase_per_loop`` new merges from the most
           frequent adjacent id pairs.

        Args:
            vocab_size: target vocabulary size; must be >= the current size.
            dataloader: iterable yielding batches (iterables) of strings.
            merge_increase_per_loop: maximum number of merges learned per batch.

        Returns:
            0 once the target size is reached (immediately if it already is);
            None if ``dataloader`` is exhausted first.

        Side effects:
            Prints every new merge. Saves a checkpoint
            ``tokenizer-<size>.model``/``.vocab`` whenever the size reaches a
            multiple of 1000 and training continues.
        """
        assert vocab_size >= self.vocab_size, (
            f"target vocab_size {vocab_size} is smaller than current {self.vocab_size}"
        )
        # Already at the target: do not learn an extra merge (previously overshot by one).
        if self.vocab_size >= vocab_size:
            return 0
        # NOTE(review): merge ids come from vocab_size without consulting
        # special_tokens_inv. If special tokens were registered with ids >=
        # vocab_size before training, the ids can collide.
        for text_batch in dataloader:
            text_chunks = []
            for text in text_batch:
                text_chunks.extend(self.pattern.findall(text))
            # Bytes of each chunk, with the already-learned merges applied.
            ids = [do_merge(list(ch.encode("utf-8")), self.merges) for ch in text_chunks]
            for _ in range(merge_increase_per_loop):
                counts = {}
                for idlist in ids:
                    counts = get_stats(idlist, counts)
                if not counts:
                    # No adjacent pairs left in this batch; continue with the next one.
                    break
                pair = max(counts, key=counts.get)
                idx = self.vocab_size
                self.merges[pair] = idx
                self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
                self.vocab_size += 1
                print(f"New merge: {pair}->{idx}")
                if self.vocab_size >= vocab_size:
                    return 0
                # Apply the new merge so the next iteration counts up-to-date pairs.
                ids = [merge_once(idlist, pair, idx) for idlist in ids]
                if self.vocab_size % 1000 == 0:
                    self.save(f"tokenizer-{self.vocab_size}")

    def add_special_tokens(self, special_tokens):
        """
        Register special tokens, replacing any previously registered ones.

        Args:
            special_tokens: dict mapping SpecialToken -> id.
        """
        self.special_tokens = special_tokens
        self.special_tokens_inv = {v: k for k, v in special_tokens.items()}

    def build_vocab(self):
        """Rebuild ``vocab`` and ``vocab_size`` from the 256 byte tokens plus ``self.merges``."""
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.vocab_size = 256
        # A merge's parts always have smaller ids than the merge itself. Building
        # in id order guarantees both parts already exist, whatever the dict order.
        for (p1, p2), idx in sorted(self.merges.items(), key=lambda item: item[1]):
            self.vocab[idx] = self.vocab[p1] + self.vocab[p2]
            self.vocab_size += 1

    def encode(self, text):
        """Encode a plain string into a flat list of token ids (special tokens are not recognized)."""
        ids = []
        for chunk in self.pattern.findall(text):
            ids.extend(do_merge(list(chunk.encode("utf-8")), self.merges))
        return ids

    def encode_all(self, text_special_mix):
        """
        Encode an iterable mixing strings and SpecialToken objects.

        Strings are BPE-encoded with ``encode``. Each SpecialToken is replaced
        by its registered id.

        Raises:
            KeyError: a SpecialToken is not registered in ``special_tokens``.
            TypeError: an item is neither ``str`` nor ``SpecialToken``.
        """
        ids = []
        for s in text_special_mix:
            if isinstance(s, str):
                ids.extend(self.encode(s))
            elif isinstance(s, SpecialToken):
                ids.append(self.special_tokens[s])
            else:
                raise TypeError(
                    f"encode_all expects str or SpecialToken items, got {type(s).__name__}"
                )
        return ids

    def decode(self, ids):
        """
        Decode token ids into a list of strings and SpecialToken objects.

        Runs of ordinary ids are joined and decoded together as UTF-8 (invalid
        sequences become U+FFFD). Each special-token id splits the output and
        appears as its SpecialToken. Unknown ids are reported with a print and
        skipped.
        """
        decoded = []
        pending = []  # bytes of consecutive ordinary tokens not yet decoded

        def flush():
            if pending:
                decoded.append(b"".join(pending).decode("utf-8", errors="replace"))
                pending.clear()

        for token_id in ids:
            if token_id in self.special_tokens_inv:
                flush()
                token = self.special_tokens_inv[token_id]
                # Values are already SpecialToken objects; wrapping them again
                # would break equality and str().
                decoded.append(token if isinstance(token, SpecialToken) else SpecialToken(token))
            elif token_id in self.vocab:
                pending.append(self.vocab[token_id])
            else:
                print(f"{token_id}: Error token id.")
        # Flush trailing text even when the last id was invalid (it used to be dropped).
        flush()
        return decoded

    def save(self, filename="tokenizer"):
        """
        Write ``<filename>.model`` (loadable with ``load_tokenizer``) and a readable ``<filename>.vocab``.

        The .model layout is:
        - the header line ``Tokenizer V1``
        - the regex pattern
        - the number of special tokens N
        - N lines ``<name> <id>``
        - one ``<p1> <p2> <idx>`` line per merge

        Both files are written as UTF-8, the encoding ``load_tokenizer`` reads with.
        """
        with open(f"{filename}.model", "w", encoding="utf-8") as f:
            f.write("Tokenizer V1\n")
            f.write(self.pattern_string + "\n")
            f.write(f"{len(self.special_tokens)}\n")
            for st, sid in self.special_tokens.items():
                f.write(f"{st.name} {sid}\n")
            for (p1, p2), idx in self.merges.items():
                f.write(f"{p1} {p2} {idx}\n")
        with open(f"{filename}.vocab", "w", encoding="utf-8") as f:
            f.write('Common Tokens:\n')
            for idx, bstr in self.vocab.items():
                # str(b'...')[2:-1] strips the b'' wrapper, leaving escapes for non-ASCII bytes.
                f.write(f"{idx} {str(bstr)[2:-1]}\n")
            f.write('Special Tokens:\n')
            for idx, spt in self.special_tokens_inv.items():
                f.write(f"{idx} {spt.name}\n")
def load_tokenizer(filename="tokenizer.model"):
    """
    Load a Tokenizer from a ``.model`` file written by ``Tokenizer.save``.

    Expected layout, one item per line:
    - the header ``Tokenizer V1``
    - the regex pattern
    - the number of special tokens N
    - N lines ``<name> <id>``
    - ``<p1> <p2> <idx>`` merge lines, up to a blank/malformed line or end of file

    Each special token is also inserted into ``vocab`` as the UTF-8 bytes of
    ``<name>``. ``vocab_size`` is set to one past the largest id in use, so ids
    later assigned by ``Tokenizer.train`` start after every special-token id.

    Raises:
        ValueError: the header is not ``Tokenizer V1``.
    """
    with open(filename, "r", encoding="utf-8") as f:
        version = f.readline().strip()
        if version != "Tokenizer V1":
            # Explicit check rather than assert: asserts are stripped under -O.
            raise ValueError(f"{filename}: unsupported tokenizer format {version!r}")
        # Only drop the line terminator; leading/trailing spaces can be part of the regex.
        pat = f.readline().rstrip("\r\n")
        num_spt = int(f.readline().strip())
        spts = {}
        for _ in range(num_spt):
            # The id is the last field; rsplit keeps names containing spaces intact.
            spt_name, spt_idx = f.readline().strip().rsplit(maxsplit=1)
            spts[SpecialToken(spt_name)] = int(spt_idx)
        merges = {}
        for line in f:
            parts = line.split()
            if len(parts) != 3:
                break  # a blank (or malformed) line ends the merge list
            p1, p2, idx = (int(x) for x in parts)
            merges[(p1, p2)] = idx
    tokenizer = Tokenizer(pat)
    tokenizer.merges = merges
    tokenizer.build_vocab()
    tokenizer.add_special_tokens(spts)
    # Make vocab cover the special tokens too, and size it past the largest id.
    max_id = max(tokenizer.vocab, default=255)
    for st, sid in spts.items():
        tokenizer.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")
        max_id = max(max_id, sid)
    tokenizer.vocab_size = max_id + 1
    return tokenizer