Spaces:
Sleeping
Sleeping
File size: 1,692 Bytes
383bfb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from torch.utils.data import DataLoader
from datasets.dataset import MyDataset
from torch.utils.data.distributed import DistributedSampler
from datasets.preprossing import *
import os
def _load_split_pairs(args, split):
    """Return the raw pairs for one dataset split, as produced by ``get_raw_pairs``.

    The file read is ``<args.dataset_dir>/<args.dataset>/<split>.json``.
    """
    data_path = os.path.join(args.dataset_dir, args.dataset, split + '.json')
    return get_raw_pairs(data_path)


def _per_process_batch_size(args):
    """Return each process's share of the global ``args.batch_size``.

    The global batch is divided evenly across ``args.nprocs`` processes; any
    remainder is discarded.

    Raises:
        ValueError: if the per-process share would be smaller than 1, which
            DataLoader would otherwise reject with a less descriptive error.
    """
    batch_size = int(args.batch_size // args.nprocs)
    if batch_size < 1:
        raise ValueError(
            'batch_size ({}) must be >= nprocs ({}) so every process gets '
            'at least one sample per batch'.format(args.batch_size, args.nprocs))
    return batch_size


def _make_loader(args, dataset, sampler, batch_size):
    """Wrap ``dataset`` in a pinned-memory DataLoader driven by ``sampler``.

    A separate ``collater(args)`` instance is created for every loader.
    """
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      pin_memory=True,
                      collate_fn=collater(args),
                      num_workers=args.workers,
                      sampler=sampler)


def get_dataloader(args):
    """Build the distributed train/test DataLoaders and the vocabularies.

    Args:
        args: config namespace. This function reads ``vocab_src_path``,
            ``vocab_tgt_path``, ``dataset_dir``, ``dataset``, ``batch_size``,
            ``nprocs`` and ``workers``; ``MyDataset`` and ``collater`` may
            read further fields.

    Returns:
        Tuple ``(train_loader, train_sampler, test_loader, src_lang, tgt_lang)``.
        ``train_sampler`` is presumably returned so the training loop can call
        ``set_epoch`` on it each epoch -- confirm at the call site.

    Raises:
        ValueError: if ``args.batch_size // args.nprocs`` is less than 1.

    ``DistributedSampler`` is built without explicit ``num_replicas``/``rank``,
    so it reads them from the default process group: ``torch.distributed``
    must be initialised before this function is called.
    """
    # Validate before any file I/O so a bad config fails fast.
    train_batch_size = _per_process_batch_size(args)

    src_lang = SrcLang(args.vocab_src_path)
    tgt_lang = TgtLang(args.vocab_tgt_path)

    train_pairs = _load_split_pairs(args, 'train')
    test_pairs = _load_split_pairs(args, 'test')

    train_data = MyDataset(args, train_pairs, src_lang, tgt_lang, is_train=True)
    train_sampler = DistributedSampler(train_data, shuffle=True)
    train_loader = _make_loader(args, train_data, train_sampler, train_batch_size)

    test_data = MyDataset(args, test_pairs, src_lang, tgt_lang, is_train=False)
    # NOTE(review): with the default drop_last=False, DistributedSampler pads
    # the test split by repeating samples so every rank gets the same count;
    # metrics aggregated across ranks may therefore count a few samples twice.
    test_sampler = DistributedSampler(test_data, shuffle=False)
    # Test batches hold a single sample per process.
    test_loader = _make_loader(args, test_data, test_sampler, 1)

    return train_loader, train_sampler, test_loader, src_lang, tgt_lang
|