File size: 1,692 Bytes
383bfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torch.utils.data import DataLoader
from datasets.dataset import MyDataset
from torch.utils.data.distributed import DistributedSampler
from datasets.preprossing import *
import os

def get_dataloader(args):
    """Build the distributed train/test data loaders and both vocabularies.

    Attributes read from ``args`` in this function: ``vocab_src_path``,
    ``vocab_tgt_path``, ``dataset_dir``, ``dataset``, ``batch_size``,
    ``nprocs`` and ``workers``. ``args`` is also forwarded unchanged to
    ``MyDataset`` and ``collater``.

    The train loader uses a shuffling ``DistributedSampler`` and a batch size
    of ``int(args.batch_size / args.nprocs)``; the test loader uses a
    non-shuffling ``DistributedSampler`` and a batch size of 1.

    NOTE(review): ``DistributedSampler`` is created without explicit
    ``num_replicas``/``rank``, so this presumably has to run after the
    torch.distributed process group is initialised -- confirm with the caller.

    Returns:
        A 5-tuple ``(train_loader, train_sampler, test_loader, src_lang,
        tgt_lang)``.
    """
    source_vocab = SrcLang(args.vocab_src_path)
    target_vocab = TgtLang(args.vocab_tgt_path)

    # Load the raw pairs of each split, train first and then test.
    train_examples, test_examples = [
        get_raw_pairs(os.path.join(args.dataset_dir, args.dataset, file_name))
        for file_name in ('train.json', 'test.json')
    ]

    # Training split: shuffled, batch size divided across the processes.
    train_set = MyDataset(
        args, train_examples, source_vocab, target_vocab, is_train=True
    )
    train_sampler = DistributedSampler(train_set, shuffle=True)
    per_process_batch = int(args.batch_size / args.nprocs)
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=per_process_batch,
        pin_memory=True,
        collate_fn=collater(args),
        num_workers=args.workers,
        sampler=train_sampler,
    )

    # Test split: fixed order, one sample per batch.
    test_set = MyDataset(
        args, test_examples, source_vocab, target_vocab, is_train=False
    )
    eval_sampler = DistributedSampler(test_set, shuffle=False)
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=1,
        pin_memory=True,
        collate_fn=collater(args),
        num_workers=args.workers,
        sampler=eval_sampler,
    )

    return train_loader, train_sampler, test_loader, source_vocab, target_vocab