"""A collection of data transformation and dataset loading functions."""

from functools import partial

from torch.utils.data import DataLoader
from torchvision import transforms

# Standard ImageNet mean and std - used to normalise the tensors
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (256, 256)

# Fraction of the held-out data that becomes the test set; the rest is validation.
_HOLDOUT_TEST_FRACTION = 0.5


def make_norm_pipeline():
    """Build the deterministic preprocessing pipeline (used for val/test).

    Returns:
        A ``transforms.Compose`` that resizes to ``IMAGE_SIZE``, converts the
        image to a tensor and standardises it with the ImageNet statistics.
    """
    # NOTE(review): Normalize is given 3-channel statistics, so inputs are
    # presumably RGB images - confirm against the dataset.
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        # Convert PIL image to a tensor, scaling pixel values from
        # [0, 255] to [0.0, 1.0]
        transforms.ToTensor(),
        # Standardise pixel values
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


def make_augment_pipeline(aug_config):
    """Build the augmentation pipeline (rotation, brightness, saturation, blur).

    Args:
        aug_config: Mapping with the keys ``'rotation'``, ``'brightness'``,
            ``'saturation'`` and ``'blur'``; each value is forwarded to the
            matching torchvision transform.

    Returns:
        A ``transforms.Compose`` that resizes, applies the random
        augmentations, converts to a tensor and normalises.

    Raises:
        KeyError: If any of the four keys is missing from ``aug_config``.
    """
    rotation = aug_config['rotation']
    brightness = aug_config['brightness']
    saturation = aug_config['saturation']
    blur = aug_config['blur']

    # Randomly perturb each picture to create "new" images and enrich the dataset
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomRotation(rotation),
        transforms.ColorJitter(brightness=brightness, saturation=saturation),
        transforms.GaussianBlur(blur),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


def apply_augmentation(batch, augmentation):
    """Replace ``batch['image']`` with the augmented version of each image.

    The batch is modified in place and also returned.
    """
    batch['image'] = [augmentation(image) for image in batch['image']]
    return batch


def apply_normalisation(batch, normalisation):
    """Replace ``batch['image']`` with the normalised version of each image.

    The batch is modified in place and also returned.
    """
    batch['image'] = [normalisation(image) for image in batch['image']]
    return batch


def _make_loader(split, batch_size, workers, shuffle):
    """Wrap one dataset split in a DataLoader with the shared settings."""
    return DataLoader(
        split,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=True,
        num_workers=workers,
    )


def make_dataset_loaders(dataset, seed, batch_size, test_size, aug_config, workers=8):
    """Create DataLoaders (train, val, test) for a given dataset.

    ``test_size`` of the data is held out from training, and the held-out part
    is then halved into validation and test sets. With ``test_size=0.3`` this
    gives a 70/15/15 split.

    Args:
        dataset: Object exposing ``train_test_split``, ``set_transform`` and
            ``features['label'].names`` (presumably a Hugging Face
            ``datasets.Dataset`` - confirm against the caller).
        seed: Seed passed to both ``train_test_split`` calls.
        batch_size: Batch size used by all three loaders.
        test_size: Value forwarded as ``test_size`` to the first split, i.e.
            the portion of ``dataset`` that is not used for training.
        aug_config: Augmentation settings, see ``make_augment_pipeline``.
        workers: Number of DataLoader worker processes.

    Returns:
        Dict with the loaders under ``"train"``, ``"val"`` and ``"test"`` and
        the label names under ``"classNames"``.
    """
    # Define transformation pipelines for the dataset
    normalisation = make_norm_pipeline()
    augmentation = make_augment_pipeline(aug_config)

    # First split separates the training set from the held-out data
    first_split = dataset.train_test_split(test_size=test_size, seed=seed)
    train_split = first_split['train']
    holdout_split = first_split['test']

    # Second split halves the held-out data into validation and test sets
    second_split = holdout_split.train_test_split(
        test_size=_HOLDOUT_TEST_FRACTION, seed=seed
    )
    val_split, test_split = second_split['train'], second_split['test']

    # Use partial over module-level functions rather than lambdas: lambdas
    # cannot be pickled by the standard pickler, which breaks DataLoader
    # workers on platforms that start processes with "spawn".
    train_transform = partial(apply_augmentation, augmentation=augmentation)
    eval_transform = partial(apply_normalisation, normalisation=normalisation)
    train_split.set_transform(train_transform)
    val_split.set_transform(eval_transform)
    test_split.set_transform(eval_transform)

    # Only the training data is shuffled
    train_loader = _make_loader(train_split, batch_size, workers, shuffle=True)
    val_loader = _make_loader(val_split, batch_size, workers, shuffle=False)
    test_loader = _make_loader(test_split, batch_size, workers, shuffle=False)

    class_names = dataset.features['label'].names

    print(f"\nWorkers used in DataLoaders: {workers}\n")

    return {
        "train": train_loader,
        "val": val_loader,
        "test": test_loader,
        "classNames": class_names,
    }