| """ | |
| Contains functionality for creating PyTorch DataLoaders for | |
| image classification data. | |
| """ | |
| import os | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader | |
# Default number of DataLoader worker processes: one per CPU.
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 0, which makes DataLoader load batches in the main process.
NUM_WORKERS = os.cpu_count() or 0
def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int = NUM_WORKERS,
):
    """Create training and testing DataLoaders from image folders.

    Both directories are wrapped in ``torchvision.datasets.ImageFolder``
    datasets using the same ``transform``, and each dataset is then wrapped
    in a ``DataLoader``.

    Args:
        train_dir: Path to the training image directory.
        test_dir: Path to the testing image directory.
        transform: Transform applied to every image in both datasets.
        batch_size: Number of samples per batch in each DataLoader.
        num_workers: Number of worker processes per DataLoader. ``None`` is
            treated as 0 (load in the main process).

    Returns:
        A tuple ``(train_dataloader, test_dataloader, class_names)`` where
        ``class_names`` is the ``classes`` list of the training dataset.

    Example:
        train_dl, test_dl, class_names = create_dataloaders(
            train_dir="path/to/train",
            test_dir="path/to/test",
            transform=some_transform,
            batch_size=32,
        )
    """
    # The module-level default comes from os.cpu_count(), which may be None;
    # DataLoader requires a non-negative integer.
    if num_workers is None:
        num_workers = 0

    # Use ImageFolder to create the datasets.
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Class names are taken from the training set.
    class_names = train_data.classes

    # Turn the datasets into DataLoaders.
    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,  # evaluation does not need shuffling; keeps order stable
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_dataloader, test_dataloader, class_names