File size: 800 Bytes
f55a095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from preprocessing.ast_processor import ast
from load_data.data_collactor import DataCollatorWithPadding

def prepare_dataloader(dataset: Dataset, batch_size: int, valid_train_flag: str):
    """Build a ``DataLoader`` over *dataset* driven by a ``DistributedSampler``.

    Args:
        dataset: Dataset to load from; it is also handed to the
            ``DistributedSampler`` that selects the samples for this process.
        batch_size: Number of samples per batch.
        valid_train_flag: Name of the split being loaded. Must be one of
            ``"train"``, ``"valid"`` or ``"test"``. All three currently use
            the same ``DataCollatorWithPadding(padding=True)`` collator.

    Returns:
        A ``DataLoader`` with ``pin_memory=True`` whose batches are assembled
        by the padding collator.

    Raises:
        ValueError: If ``valid_train_flag`` is not a recognised split name.
    """
    known_splits = ("train", "valid", "test")
    # Reject unknown flags up front: previously an unrecognised value left the
    # collator unassigned and surfaced as an opaque UnboundLocalError below.
    if valid_train_flag not in known_splits:
        raise ValueError(
            f"valid_train_flag must be one of {known_splits}, "
            f"got {valid_train_flag!r}"
        )

    # Every split uses the same collator, so build it once.
    data_collator = DataCollatorWithPadding(padding=True)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        # An explicit sampler is supplied, so the loader-level shuffle stays
        # off and sample ordering is left to the sampler.
        shuffle=False,
        sampler=DistributedSampler(dataset),
        collate_fn=data_collator,
    )