# Tokenizer training script: feeds the 'text' field of a JSON Lines corpus,
# in shuffled batches, to a previously saved tokenizer and saves the result.
import jsonlines
from torch.utils.data import Dataset, DataLoader
from tokenizer import SpecialToken, Tokenizer, load_tokenizer


class MyDataset(Dataset):
    """In-memory dataset of the ``'text'`` values read from a JSON Lines file.

    Args:
        file_path: Path handed to ``jsonlines.open``. Every record read from
            the file is indexed with the key ``'text'``.
    """

    def __init__(self, file_path):
        # The file is consumed eagerly, so the reader is already closed by
        # the time the collected texts are stored on the instance.
        with jsonlines.open(file_path, 'r') as reader:
            texts = [record['text'] for record in reader]
        self.data = texts

    def __len__(self):
        """Return the number of text entries that were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the text entry stored at position ``idx``."""
        return self.data[idx]


# Build the dataset.
file_path = 'data/zhwiki.jsonl'
dataset = MyDataset(file_path)

# Build the DataLoader that yields shuffled batches of entries.
batch_size = 1024
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Alternative (disabled): create a brand-new Tokenizer instead of loading one.
# GPT2pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# tokenizer = Tokenizer(GPT2pattern)
# Add the end-of-text special token.
# tokenizer.add_special_tokens({SpecialToken("<|endoftext|>"):50303})

# Load the previously saved tokenizer.
tokenizer = load_tokenizer("tokenizer.model")

# Train on the batches from the DataLoader.
# NOTE(review): 50303 is presumably the target vocabulary size -- confirm
# against the signature of Tokenizer.train.
tokenizer.train(50303, dataloader, merge_increase_per_loop=20)

# Save the trained tokenizer.
tokenizer.save()