| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import datasets | |
def create_datasets(train_dir, test_dir, data_transform):
    """Build the training and testing ``ImageFolder`` datasets.

    Both datasets are created with the same ``data_transform`` and no
    target transform (labels are left as ``ImageFolder`` produces them).

    Args:
        train_dir: Root directory handed to ``ImageFolder`` for training data.
        test_dir: Root directory handed to ``ImageFolder`` for testing data.
        data_transform: Transform passed as ``transform`` to both datasets.

    Returns:
        A ``(train_data, test_data)`` tuple of ``datasets.ImageFolder`` objects.
    """
    # The training folder is constructed first, then the testing folder.
    train_data, test_data = (
        datasets.ImageFolder(root=folder, transform=data_transform)
        for folder in (train_dir, test_dir)
    )
    return train_data, test_data
def create_dataloaders(train_dataset, test_dataset, batch_size, num_workers):
    """Wrap the given datasets in ``DataLoader`` objects.

    Args:
        train_dataset: Dataset served by the training loader (shuffled).
        test_dataset: Dataset served by the testing loader (not shuffled).
        batch_size: Passed as ``batch_size`` to both loaders.
        num_workers: Passed as ``num_workers`` to both loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    # Settings common to both loaders; only the shuffle flag differs.
    shared_options = {"batch_size": batch_size, "num_workers": num_workers}

    train_loader = DataLoader(train_dataset, shuffle=True, **shared_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_options)
    return train_loader, test_loader
def data_setup(train_dir, test_dir, data_transform, batch_size, num_workers):
    """Create datasets from two directories and return loaders plus class names.

    Delegates dataset construction to ``create_datasets`` and loader
    construction to ``create_dataloaders``.

    Args:
        train_dir: Directory forwarded to ``create_datasets`` for training data.
        test_dir: Directory forwarded to ``create_datasets`` for testing data.
        data_transform: Transform forwarded to ``create_datasets``.
        batch_size: Batch size forwarded to ``create_dataloaders``.
        num_workers: Worker count forwarded to ``create_dataloaders``.

    Returns:
        A ``(train_dataloader, test_dataloader, class_names)`` tuple, where
        ``class_names`` is the ``classes`` attribute of the training dataset.
    """
    train_set, test_set = create_datasets(
        train_dir=train_dir,
        test_dir=test_dir,
        data_transform=data_transform,
    )

    # Class names come from the training dataset only.
    labels = train_set.classes

    train_loader, test_loader = create_dataloaders(
        train_dataset=train_set,
        test_dataset=test_set,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return train_loader, test_loader, labels