File size: 1,862 Bytes
5153277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from torch.utils.data import Dataset
import csv
from tokenizers import Tokenizer
import torch
import os
import pickle
from src.config import Config


class TranslateDataset(Dataset):
    """Sentence-pair dataset yielding fixed-length token-id tensors.

    Pairs come from one of two places:

    - a pickle at ``config.data_cache_dir``, when ``config.use_cache`` is set
      and that file exists;
    - otherwise the CSV at ``config.wmt_zh_en_path``.

    From the CSV, column ``"0"`` is used as the source (``zh``) and column
    ``"1"`` as the target (``en``). When caching is enabled, the freshly read
    pairs are written to the cache file.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.tokenizer: Tokenizer = Tokenizer.from_file(config.tokenizer_file)
        # May be None if the vocabulary has no "[PAD]" token; encode() reports
        # that explicitly when it actually needs to pad.
        self.pad_id = self.tokenizer.token_to_id("[PAD]")
        self.pairs = []
        # Check use_cache first so the cache path is only touched (and only
        # needs to be a valid path) when caching is enabled.
        if config.use_cache and os.path.exists(config.data_cache_dir):
            # NOTE(review): pickle.load can execute arbitrary code; this is only
            # acceptable because the file is written by this class below. The
            # cache is also not invalidated if the source CSV changes.
            with open(config.data_cache_dir, "rb") as f:
                self.pairs = pickle.load(f)
        else:
            # newline="" is required by the csv module so quoted fields that
            # contain line breaks are parsed as a single field.
            with open(
                config.wmt_zh_en_path, mode="r", encoding="utf-8", newline=""
            ) as f:
                self.pairs = [(row["0"], row["1"]) for row in csv.DictReader(f)]
            if config.use_cache:
                with open(config.data_cache_dir, "wb") as cache_f:
                    pickle.dump(self.pairs, cache_f)

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.pairs)

    def encode(self, text):
        """Tokenize ``text`` into exactly ``config.max_len`` ids plus a mask.

        Returns:
            A tuple ``(ids, pad_mask)``:

            - ``ids`` is a 1-D ``torch.long`` tensor, truncated to
              ``config.max_len`` or right-padded with the ``[PAD]`` id.
            - ``pad_mask`` is a 1-D ``torch.bool`` tensor of the same length:
              ``True`` for real tokens, ``False`` for padding.

        Raises:
            ValueError: padding is needed but the tokenizer has no ``[PAD]``
                token.
        """
        max_len = self.config.max_len
        # NOTE(review): if the tokenizer's post-processor appends an end token,
        # truncation drops it for long inputs — confirm that is intended.
        ids = self.tokenizer.encode(text).ids[:max_len]
        n_real = len(ids)
        pad_len = max_len - n_real
        if pad_len > 0:
            if self.pad_id is None:
                raise ValueError("tokenizer has no '[PAD]' token; cannot pad")
            ids = ids + [self.pad_id] * pad_len
        # Build the mask from lengths rather than by comparing ids to pad_id,
        # so a real token that happens to map to the pad id is never masked.
        # NOTE(review): True here means "keep". PyTorch's key_padding_mask
        # uses True for "ignore", so callers feeding nn.Transformer /
        # nn.MultiheadAttention presumably invert this — confirm.
        pad_mask = [True] * n_real + [False] * pad_len
        return (
            torch.tensor(ids, dtype=torch.long),
            torch.tensor(pad_mask, dtype=torch.bool),
        )

    def __getitem__(self, idx):
        """Return one teacher-forcing example as a dict of tensors.

        ``src`` / ``src_pad_mask`` are the encoded source with length
        ``max_len``. ``tgt`` and ``label`` are the encoded target shifted by
        one position:

        - ``tgt`` drops the last token and ``label`` drops the first, so
          ``label[t]`` is the token following ``tgt[: t + 1]``.
        - Both have length ``max_len - 1``.
        - ``tgt_pad_mask`` is aligned with ``tgt``.
        """
        zh, en = self.pairs[idx]
        src_ids, src_mask = self.encode(zh)
        tgt_ids, tgt_mask = self.encode(en)
        # NOTE(review): label keeps pad ids at padded positions; presumably the
        # loss uses ignore_index=pad_id — verify against the training loop.
        return dict(
            src=src_ids,
            src_pad_mask=src_mask,
            tgt=tgt_ids[:-1],
            tgt_pad_mask=tgt_mask[:-1],
            label=tgt_ids[1:],
        )