File size: 939 Bytes
0c7049d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision

def make_dataloaders(train_ds,
                     test_ds,
                     batch_size: int,
                     *,
                     num_workers: int = 1):
    """Wrap a train and a test dataset in PyTorch DataLoaders.

    The training loader shuffles its samples every epoch. The test loader
    keeps the dataset order, so evaluation results are reproducible.

    Args:
        train_ds: Dataset used for training. Any object accepted by
            ``torch.utils.data.DataLoader`` works (e.g. a ``Dataset`` or a
            sequence supporting ``__len__``/``__getitem__``).
        test_ds: Dataset used for evaluation, with the same requirements.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker subprocesses used by each loader.
            Defaults to 1, which was previously hard-coded; pass 0 to load
            in the main process.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``.
    """
    train_dataloader = DataLoader(dataset=train_ds,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=True)

    # No shuffling for evaluation: keeps batch order deterministic.
    test_dataloader = DataLoader(dataset=test_ds,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False)

    return train_dataloader, test_dataloader