File size: 769 Bytes
c62c87b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from torch.utils.data import DataLoader
import torchvision

# Locations of the training and test splits on disk.
train_dir = "intel_image/seg_train"
test_dir = "intel_image/seg_test"

# Use the preprocessing pipeline that ships with the default EfficientNet-B2
# weights so inputs are prepared the way those weights expect.
weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
transform = weights.transforms()

# Both splits go through the same transform; no extra augmentation is added.
train_data = torchvision.datasets.ImageFolder(
    root=train_dir,
    transform=transform,
)
test_data = torchvision.datasets.ImageFolder(
    root=test_dir,
    transform=transform,
)

# Batches of 32 for both splits; only the training split is shuffled.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

def create_dataloaders():
    """Return the training and test dataloaders as a ``(train, test)`` pair.

    The loaders are the module-level ``train_loader`` and ``test_loader``
    objects; nothing is constructed here, so repeated calls hand back the
    same two instances.
    """
    loaders = (train_loader, test_loader)
    return loaders