Spaces:
Running
Running
| import os | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader | |
def get_data_loaders(
    data_dir="data/chest_xray",
    batch_size=4,
    *,
    image_size=(128, 128),
    num_workers=0,
):
    """Build train/val/test DataLoaders from an ImageFolder-style directory tree.

    ``data_dir`` must contain ``train``, ``val`` and ``test`` subdirectories,
    each laid out as ``<split>/<class_name>/<image files>`` as expected by
    ``torchvision.datasets.ImageFolder``. All three splits share the same
    preprocessing: resize to ``image_size``, convert to a tensor, then
    normalize with mean 0.5 / std 0.5.

    Args:
        data_dir: Root directory holding the ``train``/``val``/``test`` splits.
        batch_size: Samples per batch, used for all three loaders.
        image_size: Target size passed to ``transforms.Resize`` (a
            ``(height, width)`` pair, or an int as Resize accepts).
            Defaults to the previously hard-coded ``(128, 128)``.
        num_workers: Worker processes per DataLoader. Defaults to the
            previously hard-coded 0 (load in the main process).

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the
        training loader shuffles; val/test iterate in a fixed order.

    Raises:
        FileNotFoundError: Propagated from ``ImageFolder`` when a split
            directory is missing or contains no class subfolders.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 rescales them to
        # [-1, 1]. The single-element mean/std is applied to every channel.
        transforms.Normalize([0.5], [0.5]),
    ])

    splits = ("train", "val", "test")
    # Build all datasets first so a missing split fails before any loader
    # is created (same ordering as constructing them one by one).
    split_datasets = [
        datasets.ImageFolder(os.path.join(data_dir, split), transform=transform)
        for split in splits
    ]

    # Only the training split is shuffled; evaluation order stays deterministic.
    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=num_workers,
        )
        for split, dataset in zip(splits, split_datasets)
    )
    return train_loader, val_loader, test_loader