Spaces:
Runtime error
Runtime error
from torch.utils.data import DataLoader
import torchvision

# Use the transform that ships with the pretrained EfficientNet-B2 weights so
# images are preprocessed the way that model expects.
weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
transform = weights.transforms()

# Root folders of the training and test splits. ImageFolder derives class
# labels from the sub-directory names found under each root.
train_dir = "intel_image/seg_train"
test_dir = "intel_image/seg_test"

# Samples per batch, shared by both loaders.
BATCH_SIZE = 32

# NOTE(review): both datasets are scanned when this module is imported, so the
# import itself raises (e.g. FileNotFoundError) if either directory is absent.
# Confirm the image data is shipped alongside this code in deployment.
train_data = torchvision.datasets.ImageFolder(root=train_dir, transform=transform)
test_data = torchvision.datasets.ImageFolder(root=test_dir, transform=transform)

# Shuffle only the training split; the test split keeps a fixed order so that
# evaluation passes are reproducible.
train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, shuffle=False, batch_size=BATCH_SIZE)
def create_dataloaders(train_root=None, test_root=None, batch_size=None, num_workers=None):
    """Return the training and test dataloaders.

    Called with no arguments (the original usage), this returns the
    module-level ``train_loader`` and ``test_loader`` built at import time.
    If any argument is given, new ``ImageFolder`` datasets are built with the
    module-level ``transform`` and wrapped in new loaders. Any argument left
    as ``None`` falls back to the module-level setting.

    Args:
        train_root: Root folder of the training images (one sub-folder per
            class). Defaults to the module-level ``train_dir``.
        test_root: Root folder of the test images. Defaults to the
            module-level ``test_dir``.
        batch_size: Samples per batch for both loaders. Defaults to the
            batch size of the module-level ``train_loader``.
        num_workers: Worker processes for both loaders. Defaults to the
            ``num_workers`` of the module-level ``train_loader``.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles; the test loader keeps a fixed order.
    """
    overrides = (train_root, test_root, batch_size, num_workers)
    if all(arg is None for arg in overrides):
        # Backward-compatible path: hand back the loaders built at import time.
        return train_loader, test_loader

    # Fill unset options from the module-level configuration rather than
    # repeating literal defaults here.
    if train_root is None:
        train_root = train_dir
    if test_root is None:
        test_root = test_dir
    if batch_size is None:
        batch_size = train_loader.batch_size
    if num_workers is None:
        num_workers = train_loader.num_workers

    train_set = torchvision.datasets.ImageFolder(root=train_root, transform=transform)
    test_set = torchvision.datasets.ImageFolder(root=test_root, transform=transform)
    return (
        DataLoader(train_set, shuffle=True, batch_size=batch_size, num_workers=num_workers),
        DataLoader(test_set, shuffle=False, batch_size=batch_size, num_workers=num_workers),
    )