# Tiny-Translator / src / dataset.py
# (upload metadata: caixiaoshun's picture — "Upload 6 files" — commit 5153277 verified)
from torch.utils.data import Dataset
import csv
from tokenizers import Tokenizer
import torch
import os
import pickle
from src.config import Config
class TranslateDataset(Dataset):
    """Chinese-to-English translation dataset backed by a WMT zh-en CSV.

    Each item is a dict of tensors shaped for teacher forcing:
    ``src``/``src_pad_mask`` hold the full source sequence, while the
    target is split into ``tgt`` (all but last token) and ``label``
    (all but first token) for next-token prediction.
    """

    def __init__(self, config: Config):
        """Load sentence pairs, preferring the pickle cache when enabled.

        Args:
            config: Project configuration providing ``tokenizer_file``,
                ``wmt_zh_en_path``, ``data_cache_dir``, ``use_cache``
                and ``max_len``.
        """
        super().__init__()
        self.config = config
        self.tokenizer: Tokenizer = Tokenizer.from_file(config.tokenizer_file)
        self.pad_id = self.tokenizer.token_to_id("[PAD]")
        if config.use_cache and os.path.exists(config.data_cache_dir):
            # Cache hit: the pickle holds the already-parsed (zh, en) pairs.
            with open(config.data_cache_dir, "rb") as f:
                self.pairs = pickle.load(f)
        else:
            self.pairs = self._read_pairs()
            if config.use_cache:
                # Persist the parsed pairs so the next run skips CSV parsing.
                with open(config.data_cache_dir, "wb") as cache_f:
                    pickle.dump(self.pairs, cache_f)

    def _read_pairs(self):
        """Parse the CSV corpus into a list of (zh, en) string tuples.

        The CSV header names the columns "0" and "1"; ``__getitem__``
        treats column "0" as Chinese and "1" as English.
        NOTE(review): column semantics inferred from usage — confirm
        against the data file.
        """
        pairs = []
        with open(self.config.wmt_zh_en_path, mode="r", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                pairs.append((row["0"], row["1"]))
        return pairs

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.pairs)

    def encode(self, text):
        """Tokenize ``text`` and truncate/pad it to ``config.max_len``.

        Returns:
            A ``(ids, pad_mask)`` tuple: a LongTensor of token ids of
            length ``max_len`` and a BoolTensor where True marks a real
            (non-pad) token.
        """
        # Slicing already handles sequences shorter than max_len.
        ids = self.tokenizer.encode(text).ids[: self.config.max_len]
        pad_len = self.config.max_len - len(ids)
        if pad_len > 0:
            ids = ids + [self.pad_id] * pad_len
        # True for real tokens, False for padding positions.
        pad_mask = [i != self.pad_id for i in ids]
        return torch.tensor(ids, dtype=torch.long), torch.tensor(
            pad_mask, dtype=torch.bool
        )

    def __getitem__(self, idx):
        """Return the idx-th pair as teacher-forcing tensors.

        ``tgt`` drops the final target token and ``label`` drops the
        first, so position t of ``tgt`` predicts position t of ``label``.
        """
        zh, en = self.pairs[idx]
        zh_id, zh_pad = self.encode(zh)
        en_id, en_pad = self.encode(en)
        return dict(
            src=zh_id,
            src_pad_mask=zh_pad,
            tgt=en_id[:-1],
            tgt_pad_mask=en_pad[:-1],
            label=en_id[1:],
        )