Spaces:
Sleeping
Sleeping
| from torch.utils.data import DataLoader | |
| from datasets.dataset import MyDataset | |
| from torch.utils.data.distributed import DistributedSampler | |
| from datasets.preprossing import * | |
| import os | |
def get_dataloader(args):
    """Build distributed train/test DataLoaders for one dataset.

    Reads ``train.json`` and ``test.json`` from
    ``os.path.join(args.dataset_dir, args.dataset)``. Each split is wrapped in
    ``MyDataset`` and given its own ``DistributedSampler`` and ``DataLoader``.

    Args:
        args: Namespace-like config. Attributes read here: ``vocab_src_path``,
            ``vocab_tgt_path``, ``dataset_dir``, ``dataset``, ``batch_size``,
            ``nprocs`` and ``workers``. ``args`` itself is also forwarded to
            ``MyDataset`` and ``collater``.

    Returns:
        Tuple ``(train_loader, train_sampler, test_loader, src_lang, tgt_lang)``.
        ``train_sampler`` is returned separately, presumably so the caller can
        call ``set_epoch`` on it each epoch (verify against caller).

    Raises:
        ValueError: if ``args.batch_size // args.nprocs`` is less than 1, i.e.
            the global batch is too small to give every process a sample.
    """
    # NOTE(review): assumes args.batch_size is the *global* batch, split evenly
    # across args.nprocs processes -- TODO confirm against the launcher.
    per_proc_batch = int(args.batch_size // args.nprocs)
    if per_proc_batch < 1:
        raise ValueError(
            f"batch_size ({args.batch_size}) must be >= nprocs ({args.nprocs}) "
            "so that each process gets a per-process batch of at least 1"
        )

    src_lang = SrcLang(args.vocab_src_path)
    tgt_lang = TgtLang(args.vocab_tgt_path)

    data_root = os.path.join(args.dataset_dir, args.dataset)
    train_pairs = get_raw_pairs(os.path.join(data_root, 'train.json'))
    test_pairs = get_raw_pairs(os.path.join(data_root, 'test.json'))

    train_data = MyDataset(args, train_pairs, src_lang, tgt_lang, is_train=True)
    train_loader, train_sampler = _distributed_loader(
        args, train_data, batch_size=per_proc_batch, shuffle=True
    )

    # NOTE(review): per the PyTorch docs, DistributedSampler (drop_last=False by
    # default) pads the index list so it divides evenly across replicas, so a
    # few test samples may be evaluated on more than one rank -- confirm the
    # metric aggregation tolerates this.
    test_data = MyDataset(args, test_pairs, src_lang, tgt_lang, is_train=False)
    test_loader, _ = _distributed_loader(
        args, test_data, batch_size=1, shuffle=False
    )

    return train_loader, train_sampler, test_loader, src_lang, tgt_lang


def _distributed_loader(args, dataset, batch_size, shuffle):
    """Wrap *dataset* in a DistributedSampler and a DataLoader; return (loader, sampler).

    The sampler is built without explicit ``num_replicas``/``rank``. Per the
    PyTorch docs, it then reads both from the default process group, so
    ``torch.distributed`` must already be initialized.
    """
    sampler = DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        pin_memory=True,
        collate_fn=collater(args),  # a fresh collater per loader, as before
        num_workers=args.workers,
        sampler=sampler,
    )
    return loader, sampler