pgps-demo / datasets /__init__.py
asdfasdfdsafdsa's picture
Initial upload of PGPS demo with all dependencies
383bfb8 verified
from torch.utils.data import DataLoader
from datasets.dataset import MyDataset
from torch.utils.data.distributed import DistributedSampler
from datasets.preprossing import *
import os
def get_dataloader(args):
    """Build the distributed train/test data loaders and the vocabularies.

    Reads ``train.json`` and ``test.json`` from
    ``<args.dataset_dir>/<args.dataset>/`` via ``get_raw_pairs``, wraps each
    split in a ``MyDataset``, and samples each with a ``DistributedSampler``
    (shuffled for training, in order for testing).

    ``DistributedSampler`` is created without explicit ``num_replicas`` /
    ``rank``, so it queries the default process group: ``torch.distributed``
    must already be initialized when this is called.

    Args:
        args: Configuration namespace. Attributes read here are
            ``vocab_src_path``, ``vocab_tgt_path``, ``dataset_dir``,
            ``dataset``, ``batch_size``, ``nprocs`` and ``workers``.
            ``batch_size`` is split evenly across ``nprocs`` processes.
            ``args`` is also forwarded to ``MyDataset`` and ``collater``.

    Returns:
        Tuple ``(train_loader, train_sampler, test_loader, src_lang, tgt_lang)``.
        ``train_sampler`` is presumably returned so the caller can call its
        ``set_epoch()`` each epoch, which ``DistributedSampler`` needs in order
        to reshuffle between epochs.

    Raises:
        ValueError: If ``args.batch_size // args.nprocs`` is less than 1,
            i.e. some process would receive an empty batch.
    """
    src_lang = SrcLang(args.vocab_src_path)
    tgt_lang = TgtLang(args.vocab_tgt_path)

    split_dir = os.path.join(args.dataset_dir, args.dataset)
    train_pairs = get_raw_pairs(os.path.join(split_dir, 'train.json'))
    test_pairs = get_raw_pairs(os.path.join(split_dir, 'test.json'))

    # Each process iterates its own shard, so the configured batch size is
    # divided among the processes. Floor division avoids the float
    # round-trip; int() keeps the result an int even if the config held floats.
    per_proc_batch_size = int(args.batch_size // args.nprocs)
    if per_proc_batch_size < 1:
        raise ValueError(
            f"batch_size ({args.batch_size}) must be at least nprocs "
            f"({args.nprocs}) so that every process gets a non-empty batch"
        )

    train_data = MyDataset(args, train_pairs, src_lang, tgt_lang, is_train=True)
    train_sampler = DistributedSampler(train_data, shuffle=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=per_proc_batch_size,
        pin_memory=True,
        collate_fn=collater(args),
        num_workers=args.workers,
        sampler=train_sampler,
    )

    test_data = MyDataset(args, test_pairs, src_lang, tgt_lang, is_train=False)
    # NOTE: with the default drop_last=False, DistributedSampler pads the
    # index list by repeating samples so that every replica gets the same
    # count. With more than one process, a few test samples may therefore
    # be evaluated twice.
    test_sampler = DistributedSampler(test_data, shuffle=False)
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=1,
        pin_memory=True,
        collate_fn=collater(args),
        num_workers=args.workers,
        sampler=test_sampler,
    )

    return train_loader, train_sampler, test_loader, src_lang, tgt_lang