import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def _build_transform(image_size):
    """Return the preprocessing pipeline shared by the train/val/test splits."""
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # A single mean/std value broadcasts over every channel, mapping the
        # [0, 1] range produced by ToTensor to [-1, 1].
        transforms.Normalize([0.5], [0.5]),
    ])


def _make_loader(root, transform, batch_size, shuffle, num_workers):
    """Build an ImageFolder dataset rooted at *root* and wrap it in a DataLoader."""
    dataset = datasets.ImageFolder(root, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


def get_data_loaders(
    data_dir="data/chest_xray",
    batch_size=4,
    *,
    image_size=(128, 128),
    num_workers=0,
):
    """Create DataLoaders for the ``train``, ``val`` and ``test`` splits.

    Each split is read with ``torchvision.datasets.ImageFolder`` from the
    sub-directory of the same name under *data_dir*. All splits share one
    transform: resize to *image_size*, convert to a tensor, and normalize
    with mean 0.5 / std 0.5.

    Args:
        data_dir: Root directory containing ``train/``, ``val/`` and
            ``test/`` sub-directories.
        batch_size: Number of samples per batch for every loader.
        image_size: Target size passed to ``transforms.Resize``.
            Defaults to ``(128, 128)``.
        num_workers: Worker processes per DataLoader. Defaults to ``0``,
            meaning data is loaded in the main process.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the
        training loader shuffles; validation and test keep a fixed order.

    Raises:
        Propagates whatever ``datasets.ImageFolder`` raises when a split
        directory is missing or contains no valid class folders.
    """
    transform = _build_transform(image_size)

    # Splits are built in the same order as before (train, val, test), so
    # a missing directory is reported for the same split as before.
    train_loader = _make_loader(
        os.path.join(data_dir, "train"), transform, batch_size,
        shuffle=True, num_workers=num_workers,
    )
    val_loader = _make_loader(
        os.path.join(data_dir, "val"), transform, batch_size,
        shuffle=False, num_workers=num_workers,
    )
    test_loader = _make_loader(
        os.path.join(data_dir, "test"), transform, batch_size,
        shuffle=False, num_workers=num_workers,
    )
    return train_loader, val_loader, test_loader