# smallGroupProject/dataPrep/helpers/transforms_loaders.py
# Author: Yusuf — per-class accuracy work (commit ed657fc)
"""
A collection of data transformation and dataset loading functions.
"""
from torchvision import transforms
from torch.utils.data import DataLoader
# Standard ImageNet mean and std - Used to normalize the tensors
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel (R, G, B) mean
IMAGENET_STD = [0.229, 0.224, 0.225]  # per-channel (R, G, B) standard deviation
# Target spatial size (height, width) every image is resized to before tensor conversion
IMAGE_SIZE = (256, 256)
def make_norm_pipeline():
    """Build the deterministic preprocessing pipeline used for val/test data.

    Resizes each image to IMAGE_SIZE, converts it to a tensor, and
    standardises it with the ImageNet channel statistics.
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        # PIL Image -> PyTorch tensor; pixel range [0, 255] -> [0.0, 1.0]
        transforms.ToTensor(),
        # Zero-centre and scale each channel using the ImageNet stats
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)
def make_augment_pipeline(aug_config):
    """Build the training-time augmentation pipeline.

    Args:
        aug_config: dict with keys 'rotation', 'brightness', 'saturation'
            and 'blur', passed to the corresponding torchvision transforms.

    Returns:
        A transforms.Compose that augments, tensorises and normalises
        an input image.
    """
    # Randomly perturb images so the training set is effectively enlarged
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomRotation(aug_config['rotation']),
        transforms.ColorJitter(
            brightness=aug_config['brightness'],
            saturation=aug_config['saturation'],
        ),
        transforms.GaussianBlur(aug_config['blur']),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
def apply_augmentation(batch, augmentation):
    """Run *augmentation* over every image in *batch* (mutates the batch)."""
    transformed = []
    for image in batch['image']:
        transformed.append(augmentation(image))
    batch['image'] = transformed
    return batch
def apply_normalisation(batch, normalisation):
    """Run *normalisation* over every image in *batch* (mutates the batch)."""
    batch['image'] = list(map(normalisation, batch['image']))
    return batch
"""
Creates and returns DataLoaders (train, val, test) for a given dataset.
Performs a 70/15/15 split
"""
def make_dataset_loaders(dataset, seed, batch_size, test_size, aug_config, workers=8):
# Define transformation pipelines for the dataset
normalisation = make_norm_pipeline()
augmentation = make_augment_pipeline(aug_config)
# 70/30 split creates train set
split_1 = dataset.train_test_split(test_size=test_size, seed=seed)
train_split = split_1['train']
remaining_split = split_1['test']
# 15/15 split on remaining data - validation and test sets
val_split = 0.5
split_2 = remaining_split.train_test_split(test_size=val_split, seed=seed)
val_split, test_split = split_2['train'], split_2['test']
# Put each split through pipelines
train_split.set_transform(lambda batch: apply_augmentation(batch, augmentation))
val_split.set_transform(lambda batch: apply_normalisation(batch, normalisation))
test_split.set_transform(lambda batch: apply_normalisation(batch, normalisation))
# Create dataloader for each
train_loader = DataLoader(
train_split,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=workers
)
val_loader = DataLoader(
val_split,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=workers
)
test_loader = DataLoader(
test_split,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=workers
)
class_names = dataset.features['label'].names
print(f"\nWorkers used in DataLoaders: {workers}\n")
dataset_loaders = {
"train": train_loader,
"val": val_loader,
"test": test_loader,
"classNames": class_names
}
return dataset_loaders