import jsonlines
from torch.utils.data import Dataset, DataLoader
from tokenizer import SpecialToken, Tokenizer, load_tokenizer
class MyDataset(Dataset):
    """Dataset holding the ``'text'`` value of every record in a JSON Lines file.

    All values are read eagerly in the constructor and kept in the list
    ``self.data``; length and indexing are answered from that list.
    """

    def __init__(self, file_path):
        """Load the ``'text'`` entry of each record found in ``file_path``.

        Args:
            file_path: Location of the JSON Lines file, handed to
                ``jsonlines.open`` in read mode.

        A record that cannot be indexed with ``'text'`` makes the lookup
        raise while loading, so no partially filled dataset is produced.
        """
        with jsonlines.open(file_path, 'r') as reader:
            texts = [record['text'] for record in reader]
        self.data = texts

    def __len__(self):
        """Return how many entries were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the stored entry selected by ``idx`` (plain list indexing)."""
        return self.data[idx]
# Training script: load the corpus, batch it, then continue training a saved
# tokenizer on it.
# NOTE(review): these statements run at import time because they sit at module
# level; consider an `if __name__ == "__main__":` guard if MyDataset is ever
# imported from elsewhere.
# Create the dataset from the JSON Lines corpus file.
file_path = 'data/zhwiki.jsonl'
dataset = MyDataset(file_path)
# Create the DataLoader: shuffled batches of 1024 dataset items
# (the final batch may be smaller).
batch_size = 1024
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Create a Tokenizer from scratch (disabled: a saved tokenizer is loaded below
# instead).
# GPT2pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# tokenizer = Tokenizer(GPT2pattern)
# Add the end-of-text special token (also disabled).
# tokenizer.add_special_tokens({SpecialToken("<|endoftext|>"):50303})
# Load the previously saved tokenizer.
tokenizer = load_tokenizer("tokenizer.model")
# Train on the batches produced by the DataLoader.
# NOTE(review): 50303 is presumably the target vocabulary size and
# merge_increase_per_loop the number of merges added per pass -- confirm
# against the project's Tokenizer.train.
tokenizer.train(50303, dataloader, merge_increase_per_loop=20)
# Save the trained tokenizer. No path is passed here, so the destination is
# whatever the tokenizer's own save() chooses.
tokenizer.save()
|