Spaces:
Runtime error
Runtime error
File size: 939 Bytes
0c7049d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision
def make_dataloaders(train_ds,
                     test_ds,
                     batch_size: int,
                     *,
                     num_workers: int = 1):
    """Create train and test dataloaders from already-built datasets.

    The training loader reshuffles its samples every epoch; the test loader
    iterates the dataset in its original order.

    Args:
        train_ds: Dataset used for training (anything ``DataLoader`` accepts,
            e.g. a ``torch.utils.data.Dataset`` or an indexable sequence).
        test_ds: Dataset used for evaluation.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker subprocesses each loader uses.
            Defaults to 1 (the previously hard-coded value); pass 0 to load
            data in the main process.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``.
    """
    train_dataloader = DataLoader(dataset=train_ds,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=True)
    # No shuffling for evaluation: keeps batches in dataset order.
    test_dataloader = DataLoader(dataset=test_ds,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False)
    return train_dataloader, test_dataloader
|