Spaces:
Sleeping
Sleeping
| """ | |
| A collection of data transformation and dataset loading functions. | |
| """ | |
| from torchvision import transforms | |
| from torch.utils.data import DataLoader | |
# Standard ImageNet per-channel (RGB) statistics, used by transforms.Normalize
# to standardise pixel values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Target (height, width) every image is resized to before tensor conversion.
IMAGE_SIZE = (256, 256)
# Defines and returns the normalization pipeline.
def make_norm_pipeline():
    """Build the deterministic preprocessing pipeline (for val/test data).

    Resizes to IMAGE_SIZE, converts a PIL image to a float tensor in
    [0.0, 1.0], then standardises channels with the ImageNet mean/std.
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        # PIL Image -> PyTorch tensor; pixel values rescaled from [0, 255] to [0.0, 1.0]
        transforms.ToTensor(),
        # Channel-wise standardisation with ImageNet statistics
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)
# Defines and returns the augmentation (rotation, brightness, saturation, blur) pipeline.
def make_augment_pipeline(aug_config):
    """Build the stochastic training pipeline from an augmentation config.

    ``aug_config`` must provide the keys 'rotation', 'brightness',
    'saturation' and 'blur'. Random perturbations create "new" images
    on the fly to enrich the training set.
    """
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        # Randomly perturb geometry and colour of each picture
        transforms.RandomRotation(aug_config['rotation']),
        transforms.ColorJitter(
            brightness=aug_config['brightness'],
            saturation=aug_config['saturation'],
        ),
        transforms.GaussianBlur(aug_config['blur']),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
def apply_augmentation(batch, augmentation):
    """Run every image in the batch through the augmentation transform.

    Mutates ``batch['image']`` in place and returns the batch (the shape
    expected by HuggingFace ``set_transform`` callbacks).
    """
    batch['image'] = list(map(augmentation, batch['image']))
    return batch
def apply_normalisation(batch, normalisation):
    """Run every image in the batch through the normalisation transform.

    Mutates ``batch['image']`` in place and returns the batch (the shape
    expected by HuggingFace ``set_transform`` callbacks).
    """
    batch['image'] = list(map(normalisation, batch['image']))
    return batch
| """ | |
| Creates and returns DataLoaders (train, val, test) for a given dataset. | |
| Performs a 70/15/15 split | |
| """ | |
| def make_dataset_loaders(dataset, seed, batch_size, test_size, aug_config, workers=8): | |
| # Define transformation pipelines for the dataset | |
| normalisation = make_norm_pipeline() | |
| augmentation = make_augment_pipeline(aug_config) | |
| # 70/30 split creates train set | |
| split_1 = dataset.train_test_split(test_size=test_size, seed=seed) | |
| train_split = split_1['train'] | |
| remaining_split = split_1['test'] | |
| # 15/15 split on remaining data - validation and test sets | |
| val_split = 0.5 | |
| split_2 = remaining_split.train_test_split(test_size=val_split, seed=seed) | |
| val_split, test_split = split_2['train'], split_2['test'] | |
| # Put each split through pipelines | |
| train_split.set_transform(lambda batch: apply_augmentation(batch, augmentation)) | |
| val_split.set_transform(lambda batch: apply_normalisation(batch, normalisation)) | |
| test_split.set_transform(lambda batch: apply_normalisation(batch, normalisation)) | |
| # Create dataloader for each | |
| train_loader = DataLoader( | |
| train_split, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| pin_memory=True, | |
| num_workers=workers | |
| ) | |
| val_loader = DataLoader( | |
| val_split, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| pin_memory=True, | |
| num_workers=workers | |
| ) | |
| test_loader = DataLoader( | |
| test_split, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| pin_memory=True, | |
| num_workers=workers | |
| ) | |
| class_names = dataset.features['label'].names | |
| print(f"\nWorkers used in DataLoaders: {workers}\n") | |
| dataset_loaders = { | |
| "train": train_loader, | |
| "val": val_loader, | |
| "test": test_loader, | |
| "classNames": class_names | |
| } | |
| return dataset_loaders | |