Spaces:
Sleeping
Sleeping
import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def get_data_loaders(data_dir, batch_size=32, image_size=128, num_workers=0):
    """Build training and validation DataLoaders from an ImageFolder layout.

    ``data_dir`` must contain ``training/`` and ``validation/`` subdirectories,
    each organised as ``<split>/<class_name>/<image files>`` as expected by
    ``torchvision.datasets.ImageFolder``.

    Args:
        data_dir: Root directory (str or path-like) holding both splits.
        batch_size: Number of samples per batch for both loaders.
        image_size: Side length, in pixels, of the square images produced by
            both pipelines (random crop for training, plain resize for
            validation). Defaults to 128.
        num_workers: Worker processes used by each DataLoader; 0 loads data
            in the main process.

    Returns:
        A tuple ``(train_loader, val_loader, classes)``. ``classes`` is the
        list of class-folder names of the training split, which is verified
        to match the validation split.

    Raises:
        ValueError: If the training and validation splits do not contain the
            same set of class folders.
    """
    # Data augmentation + normalization for training.
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # Validation is deterministic: only resize + the same normalization.
    transform_val = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    train_dir = os.path.join(data_dir, "training")
    val_dir = os.path.join(data_dir, "validation")

    train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
    val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

    # ImageFolder derives label indices from each split's own folder list, so
    # differing class sets would silently misalign validation labels with the
    # labels the model was trained on.
    if val_dataset.classes != train_dataset.classes:
        raise ValueError(
            f"Class folders differ between {train_dir!r} and {val_dir!r}: "
            f"{train_dataset.classes} vs {val_dataset.classes}"
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    # No shuffling for validation so evaluation order is reproducible.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, val_loader, train_dataset.classes