File size: 975 Bytes
2b74065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def get_data_loaders(data_dir="data/chest_xray", batch_size=4, *,
                     image_size=(128, 128), num_workers=0):
    """Build train/val/test DataLoaders from an ImageFolder directory tree.

    ``data_dir`` must contain ``train``, ``val`` and ``test`` sub-directories,
    each of which is loaded with ``torchvision.datasets.ImageFolder``.
    All three splits share the same preprocessing: resize, convert to tensor,
    normalize.

    Args:
        data_dir: Root directory holding the three split folders.
        batch_size: Number of samples per batch for every loader.
        image_size: Target size passed to ``transforms.Resize``
            (default ``(128, 128)``).
        num_workers: Number of worker processes for each ``DataLoader``
            (default ``0``, i.e. load in the main process).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles; the validation and test loaders keep a fixed order.

    Raises:
        FileNotFoundError: If any of the split directories does not exist.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # A single mean/std value is applied to every channel: (x - 0.5) / 0.5
        # maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5], [0.5]),
    ])

    def _load_split(split, shuffle):
        """Create the DataLoader for one split sub-directory of ``data_dir``."""
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            # Fail early with a message that names the missing split.
            raise FileNotFoundError(
                f"Dataset split directory not found: {split_dir}"
            )
        dataset = datasets.ImageFolder(split_dir, transform=transform)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_loader = _load_split("train", shuffle=True)
    val_loader = _load_split("val", shuffle=False)
    test_loader = _load_split("test", shuffle=False)

    return train_loader, val_loader, test_loader