CoLMbo / load_data /prepare_dataloader.py
massabaali's picture
Upload CoLMbo model weights and code
f55a095 verified
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from preprocessing.ast_processor import ast
from load_data.data_collactor import DataCollatorWithPadding
def prepare_dataloader(dataset: Dataset, batch_size: int, valid_train_flag: str):
    """Build a distributed ``DataLoader`` for one data split.

    Args:
        dataset: Dataset to wrap; it is sharded across ranks by a
            ``DistributedSampler`` constructed with default arguments.
        batch_size: Number of samples per batch yielded by the loader.
        valid_train_flag: Split name, one of ``"train"``, ``"valid"`` or
            ``"test"``. All three splits currently use the same collator.

    Returns:
        A ``DataLoader`` with ``pin_memory=True``, a ``DistributedSampler``
        over ``dataset`` and ``DataCollatorWithPadding(padding=True)`` as its
        collate function.

    Raises:
        ValueError: If ``valid_train_flag`` is not a recognised split name.

    Note:
        Per the PyTorch docs, ``DistributedSampler`` built without explicit
        ``num_replicas``/``rank`` needs an initialized default process group
        (``torch.distributed``) before this function is called.
    """
    # Validate up front so an unknown flag fails with a clear message rather
    # than an UnboundLocalError on an unassigned collator.
    if valid_train_flag not in ("train", "valid", "test"):
        raise ValueError(
            "valid_train_flag must be 'train', 'valid' or 'test', "
            f"got {valid_train_flag!r}"
        )

    # Every split shares the same padding collator.
    data_collator = DataCollatorWithPadding(padding=True)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        # Ordering is delegated to the sampler: DataLoader rejects
        # shuffle=True together with an explicit sampler. DistributedSampler
        # shuffles by default (for every split here); callers should call
        # ``loader.sampler.set_epoch(epoch)`` each epoch to vary the order.
        shuffle=False,
        sampler=DistributedSampler(dataset),
        collate_fn=data_collator,
    )