import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from constants.tokens import PAD_ID


def encode_with_specials(token_ids, vocab, add_sos_eos=False):
    """Map tokens to vocab ids, optionally framing the sequence with <sos>/<eos>."""
    if add_sos_eos:
        return [vocab['<sos>']] + [vocab[t] for t in token_ids] + [vocab['<eos>']]
    return [vocab[t] for t in token_ids]
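
# Quick illustration (toy vocab for this comment only; the ids are arbitrary):
#   vocab = {'<sos>': 1, '<eos>': 2, 'a': 3, 'b': 4}
#   encode_with_specials(['a', 'b'], vocab)                   -> [3, 4]
#   encode_with_specials(['a', 'b'], vocab, add_sos_eos=True) -> [1, 3, 4, 2]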


class WordGenDataset(Dataset):
    """Seq2seq pairs: sources are encoded as-is, targets framed with <sos>/<eos>."""

    def __init__(self, inputs, outputs, vocab, max_len=64):
        self.inputs = inputs
        self.outputs = outputs
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Truncate before encoding so sequences respect max_len (previously
        # stored but unused); the target budget reserves two slots for the
        # <sos>/<eos> specials added during encoding.
        src = self.inputs[idx][:self.max_len]
        tgt = self.outputs[idx][:self.max_len - 2]
        x = encode_with_specials(src, self.vocab)
        y = encode_with_specials(tgt, self.vocab, add_sos_eos=True)
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


def collate_fn(batch):
    """Pad a batch of variable-length (x, y) pairs into rectangular tensors."""
    xs, ys = zip(*batch)
    xs = pad_sequence(xs, batch_first=True, padding_value=PAD_ID)
    ys = pad_sequence(ys, batch_first=True, padding_value=PAD_ID)
    return xs, ys
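

# Minimal usage sketch (illustrative only): the toy vocab below assumes
# PAD_ID == 0, matching vocab['<pad>']; real ids come from constants.tokens
# and the project's actual vocabulary.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, 'cat': 3, 'dog': 4}
    inputs = [['cat', 'dog'], ['dog']]
    outputs = [['dog', 'cat'], ['dog']]

    loader = DataLoader(
        WordGenDataset(inputs, outputs, vocab),
        batch_size=2,
        collate_fn=collate_fn,
    )
    xs, ys = next(iter(loader))
    # xs: (batch, max_src_len), padded with PAD_ID
    # ys: (batch, max_tgt_len), framed with <sos>/<eos> and padded with PAD_ID
    print(xs.shape, ys.shape)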