| import jsonlines | |
| from torch.utils.data import Dataset, DataLoader | |
| from tokenizer import SpecialToken, Tokenizer, load_tokenizer | |
class MyDataset(Dataset):
    """In-memory dataset holding the ``'text'`` value of each JSON Lines record."""

    def __init__(self, file_path):
        """Eagerly read *file_path* and store every record's ``'text'`` value.

        Each record is indexed with ``['text']``, so records are assumed to be
        mappings that contain that key.
        """
        with jsonlines.open(file_path, 'r') as reader:
            texts = [record['text'] for record in reader]
        self.data = texts

    def __len__(self):
        """Return how many texts were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the text stored at position *idx*."""
        return self.data[idx]
# Build the dataset from the JSON Lines file.
file_path = 'data/zhwiki.jsonl'
dataset = MyDataset(file_path)

# Wrap the dataset in a DataLoader that serves shuffled batches.
batch_size = 1024
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Alternative: create a new Tokenizer from a split pattern instead of loading one.
# GPT2pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# tokenizer = Tokenizer(GPT2pattern)
# Alternative (continued): register the end-of-text special token.
# tokenizer.add_special_tokens({SpecialToken("<|endoftext|>"):50303})

# Load the previously saved tokenizer.
tokenizer = load_tokenizer("tokenizer.model")

# Train on the batches from the DataLoader.
# NOTE(review): 50303 is presumably the target vocabulary size - confirm
# against Tokenizer.train in tokenizer.py.
tokenizer.train(50303, dataloader, merge_increase_per_loop=20)

# Save the trained tokenizer.
tokenizer.save()