from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def get_dataloader(
    batch_size=16,
    img_size=128,
    data_path="./data",
    *,
    shuffle=True,
    num_workers=0,
):
    """Build a DataLoader over an ``ImageFolder`` dataset rooted at *data_path*.

    ``ImageFolder`` expects one sub-directory per class under *data_path*
    (e.g. ``data_path/cat/xxx.png``, ``data_path/dog/yyy.png``). Each sample's
    label is the index of its class sub-directory.

    Every image is resized to a fixed ``img_size x img_size`` square and then
    converted to a tensor with ``transforms.ToTensor()``.

    Args:
        batch_size: Number of samples per batch.
        img_size: Side length, in pixels, of the square each image is resized to.
        data_path: Root directory of the image-folder dataset.
        shuffle: Whether to reshuffle the samples every epoch. Defaults to
            ``True`` to preserve the original behavior.
        num_workers: Number of worker subprocesses used for loading.
            ``0`` (the default, matching the original behavior) loads in the
            main process.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        # Resizing to an explicit (h, w) pair does not preserve aspect ratio;
        # non-square images are stretched so every tensor has the same shape
        # and can be batched together.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(root=data_path, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )