from torch.utils.data import Dataset, DataLoader from torch.utils.data.distributed import DistributedSampler from preprocessing.ast_processor import ast from load_data.data_collactor import DataCollatorWithPadding def prepare_dataloader(dataset: Dataset, batch_size: int, valid_train_flag: str): if valid_train_flag == "train": data_collator = DataCollatorWithPadding(padding=True) elif valid_train_flag == "valid": data_collator = DataCollatorWithPadding(padding=True) elif valid_train_flag == "test": data_collator = DataCollatorWithPadding(padding=True) return DataLoader( dataset, batch_size=batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset), collate_fn=data_collator )