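"""Dataset and batching utilities for a word-generation model.

`encode_with_specials` maps tokens to vocab ids (optionally wrapped in
<sos>/<eos>), `WordGenDataset` yields encoded (input, output) pairs,
and `collate_fn` right-pads each batch with PAD_ID.
"""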
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from constants.tokens import PAD_ID

def encode_with_specials(tokens, vocab, add_sos_eos=False):
    """Map tokens to vocab ids, optionally wrapped in <sos>/<eos> markers."""
    ids = [vocab[t] for t in tokens]
    if add_sos_eos:
        return [vocab['<sos>']] + ids + [vocab['<eos>']]
    return ids

class WordGenDataset(Dataset):
    """Paired input/output token sequences for word generation."""

    def __init__(self, inputs, outputs, vocab, max_len=64):
        self.inputs = inputs
        self.outputs = outputs
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Truncate to max_len *before* adding specials so a long target
        # still ends with <eos>. The source keeps raw ids; the target is
        # wrapped in <sos>/<eos> for teacher forcing.
        x = encode_with_specials(self.inputs[idx][:self.max_len], self.vocab)
        y = encode_with_specials(self.outputs[idx][:self.max_len], self.vocab,
                                 add_sos_eos=True)
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

def collate_fn(batch):
    """Right-pad variable-length sequences so each batch stacks into tensors."""
    xs, ys = zip(*batch)
    xs = pad_sequence(xs, batch_first=True, padding_value=PAD_ID)
    ys = pad_sequence(ys, batch_first=True, padding_value=PAD_ID)
    return xs, ys
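

# --- Usage sketch (illustrative only) ---
# A minimal example of wiring WordGenDataset and collate_fn into a
# DataLoader. The toy vocab and sequences below are assumptions for
# demonstration and assume PAD_ID == 0; in practice the vocab and data
# come from the training corpus.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    vocab = {'<pad>': PAD_ID, '<sos>': 1, '<eos>': 2, 'a': 3, 'b': 4, 'c': 5}
    inputs = [['a', 'b'], ['c']]
    outputs = [['b', 'a'], ['c', 'c', 'a']]

    loader = DataLoader(
        WordGenDataset(inputs, outputs, vocab),
        batch_size=2,
        collate_fn=collate_fn,
    )

    xs, ys = next(iter(loader))
    # xs: (2, 2) padded sources; ys: (2, 5) padded <sos>...<eos> targets
    print(xs.shape, ys.shape)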