import csv
import os
import pickle

import torch
from tokenizers import Tokenizer
from torch.utils.data import Dataset

from src.config import Config


class TranslateDataset(Dataset):
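    """zh -> en translation pairs read from a WMT CSV and tokenized on the fly.

    Parsed pairs are optionally pickled to disk so the CSV is only read once.
    """
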
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.tokenizer: Tokenizer = Tokenizer.from_file(config.tokenizer_file)
        self.pad_id = self.tokenizer.token_to_id("[PAD]")
        self.pairs = []
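        # Reuse the pickled pairs when caching is enabled and a cache file exists;
        # otherwise parse the raw CSV and (optionally) write the cache.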
        if os.path.exists(config.data_cache_dir) and config.use_cache:
            with open(config.data_cache_dir, "rb") as f:
                self.pairs = pickle.load(f)
        else:
            with open(self.config.wmt_zh_en_path, mode="r", encoding="utf-8") as f:
                reader = csv.DictReader(f)
                for line in reader:
                    self.pairs.append((line["0"], line["1"]))
            if config.use_cache:
                with open(config.data_cache_dir, "wb") as cache_f:
                    pickle.dump(self.pairs, cache_f)

    def __len__(self):
        return len(self.pairs)
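
    # Tokenize a sentence, truncate/pad it to config.max_len, and build a boolean
    # mask that is True at real-token positions and False at [PAD] positions.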
    def encode(self, text):
        ids = self.tokenizer.encode(text).ids
        if len(ids) > self.config.max_len:
            ids = ids[: self.config.max_len]
        pad_len = self.config.max_len - len(ids)
        if pad_len > 0:
            ids = ids + [self.pad_id] * pad_len
        pad_mask = [i != self.pad_id for i in ids]
        return torch.tensor(ids, dtype=torch.long), torch.tensor(
            pad_mask, dtype=torch.bool
        )

    def __getitem__(self, idx):
        zh, en = self.pairs[idx]
        zh_id, zh_pad = self.encode(zh)
        en_id, en_pad = self.encode(en)
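        # Shift the English sequence by one step for teacher forcing: the decoder
        # consumes en_id[:-1] and is trained to predict en_id[1:].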
        return dict(
            src=zh_id,
            src_pad_mask=zh_pad,
            tgt=en_id[:-1],
            tgt_pad_mask=en_pad[:-1],
            label=en_id[1:],
        )
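

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes Config() can be constructed with defaults and exposes the fields used
# above (tokenizer_file, data_cache_dir, use_cache, wmt_zh_en_path, max_len);
# the batch size is an arbitrary example value.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = TranslateDataset(Config())
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    # src: (batch, max_len) token ids; tgt/label: (batch, max_len - 1) ids;
    # *_pad_mask: boolean masks that are True at non-[PAD] positions.
    print({k: v.shape for k, v in batch.items()})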