File size: 1,106 Bytes
e10f35b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import jsonlines
from torch.utils.data import Dataset, DataLoader
from tokenizer import SpecialToken, Tokenizer, load_tokenizer

class MyDataset(Dataset):
    """Dataset holding the ``'text'`` value of every record in a JSON Lines file.

    The whole file is read once in the constructor and kept in memory as a
    list on ``self.data``.
    """

    def __init__(self, file_path):
        """Load all records from ``file_path`` and keep their ``'text'`` values.

        Args:
            file_path: Path handed to ``jsonlines.open``; each record read
                from it is indexed with ``'text'``.
        """
        with jsonlines.open(file_path, 'r') as reader:
            texts = [record['text'] for record in reader]
        self.data = texts

    def __len__(self):
        """Return the number of loaded texts."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the text stored at position ``idx`` of ``self.data``."""
        return self.data[idx]

# Build the dataset. Everything below runs at module level, so it executes as
# soon as this file is run or imported.
file_path = 'data/zhwiki.jsonl'
dataset = MyDataset(file_path)

# Wrap the dataset in a DataLoader that yields shuffled batches of up to
# 1024 texts.
batch_size = 1024
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Alternative (disabled): create a fresh Tokenizer from the GPT-2 split
# pattern instead of loading an existing model.
# GPT2pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# tokenizer = Tokenizer(GPT2pattern)

# Alternative (disabled): register the end-of-text special token.
# tokenizer.add_special_tokens({SpecialToken("<|endoftext|>"):50303})

# Load the existing tokenizer from "tokenizer.model".
tokenizer = load_tokenizer("tokenizer.model")

# Train the tokenizer on the batches from the DataLoader.
# NOTE(review): 50303 is presumably the target vocabulary size and
# merge_increase_per_loop the number of merges added per training pass --
# confirm against the project's tokenizer module.
tokenizer.train(50303, dataloader, merge_increase_per_loop=20)

# Save the trained tokenizer. No path is passed, so the destination is
# whatever the tokenizer's save() uses by default -- TODO confirm.
tokenizer.save()