Spaces:
Sleeping
Sleeping
import os

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

def get_dataloaders(
    data_dir,
    batch_size=32,
    val_split=0.2,
    *,
    image_size=(128, 128),
    seed=None,
    num_workers=0,
):
    """Build train/validation DataLoaders from an ImageFolder-style directory.

    Every image is resized, converted to a tensor and normalized with
    mean 0.5 / std 0.5 on each of three channels. The dataset is then
    randomly split into a training subset and a validation subset. Both
    subsets share the same transform (no augmentation is applied).

    Args:
        data_dir: Root directory in the layout expected by
            ``torchvision.datasets.ImageFolder`` (one sub-folder per class).
        batch_size: Batch size used for both loaders.
        val_split: Fraction of the dataset reserved for validation. Must be
            in ``[0, 1)``. The validation size is ``int(len(dataset) * val_split)``
            (rounded down); the rest goes to training.
        image_size: Passed unchanged to ``transforms.Resize``. The default
            ``(128, 128)`` matches the previous hard-coded size.
        seed: Optional integer seed for the train/val split. When ``None``
            (the default), the global torch RNG is used, as before. When set,
            the split is reproducible across runs.
        num_workers: Worker-process count forwarded to both ``DataLoader``s.
            The default ``0`` matches ``DataLoader``'s own default.

    Returns:
        A tuple ``(train_loader, val_loader, classes)``. ``classes`` is the
        ``ImageFolder.classes`` list of class-folder names. Only the training
        loader shuffles.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    # Validate before touching the filesystem. A negative split would hand
    # random_split a negative length, and a split >= 1 leaves an empty training
    # set that the shuffling sampler rejects with a confusing error.
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # Same normalization constants as before: mean 0.5, std 0.5 for each
        # of the three channels.
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)

    val_len = int(len(dataset) * val_split)
    train_len = len(dataset) - val_len

    # Pass a generator only when a seed is requested, so the default path
    # behaves exactly as the original did (global RNG).
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [train_len, val_len], **split_kwargs)

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers)
    return train_loader, val_loader, dataset.classes