| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader | |
def get_dataloader(
    batch_size=16,
    img_size=128,
    data_path="./data",
    *,
    shuffle=True,
    num_workers=0,
):
    """Build a DataLoader over an ``ImageFolder`` dataset rooted at *data_path*.

    Every image is resized to a square ``img_size x img_size`` (the aspect
    ratio is NOT preserved) and then converted with ``transforms.ToTensor()``.

    Args:
        batch_size: Number of samples per batch.
        img_size: Side length, in pixels, of the square resize target.
        data_path: Root directory handed to ``datasets.ImageFolder``;
            presumably laid out as one sub-directory per class, per
            ImageFolder's convention — confirm against the data on disk.
        shuffle: Whether the loader reshuffles samples every epoch. Defaults
            to ``True`` to keep the original behavior; pass ``False`` for
            evaluation/validation loaders that need a stable order.
        num_workers: Number of worker processes used for loading. ``0`` (the
            ``DataLoader`` default, and the original behavior) loads in the
            main process.

    Returns:
        A ``torch.utils.data.DataLoader`` over the transformed ImageFolder
        dataset.
    """
    # Resize to a fixed square first so every sample in a batch has the same
    # spatial size and can be stacked by the default collate function.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(root=data_path, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )