Spaces:
Sleeping
Sleeping
File size: 3,713 Bytes
04cb886 2ace27a 83d4d7f 2ace27a 04cb886 83d4d7f 04cb886 2ace27a 04cb886 83d4d7f 04cb886 2ace27a 9dbc9de 2ace27a 04cb886 2ace27a 04cb886 78fbc90 04cb886 78fbc90 04cb886 2ace27a 04cb886 2ace27a 04cb886 78fbc90 04cb886 78fbc90 ed657fc 78fbc90 04cb886 ed657fc 04cb886 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
"""
A collection of data transformation and dataset loading functions.
"""
from torchvision import transforms
from torch.utils.data import DataLoader
# Standard ImageNet mean and std - Used to normalize the tensors
# (per-channel RGB statistics computed over the ImageNet training set,
# the conventional values used with ImageNet-pretrained backbones).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Target (height, width) that every image is resized to before tensor conversion.
IMAGE_SIZE = (256, 256)
def make_norm_pipeline():
    """Build the deterministic preprocessing pipeline for val/test images.

    Steps: resize to IMAGE_SIZE, convert PIL image to a float tensor
    (scaling pixel values from [0, 255] to [0.0, 1.0]), then standardise
    each channel with the ImageNet mean/std constants.
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        # PIL -> tensor, rescaling pixel values to [0.0, 1.0]
        transforms.ToTensor(),
        # Channel-wise standardisation with ImageNet statistics
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)
def make_augment_pipeline(aug_config):
    """Build the training-time augmentation pipeline.

    ``aug_config`` must provide the keys 'rotation', 'brightness',
    'saturation' and 'blur'; each value is forwarded directly to the
    corresponding torchvision transform. The pipeline enriches the
    training data by randomly perturbing each image before the usual
    resize/tensor/normalise steps.
    """
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        # Random perturbations to effectively create "new" training images
        transforms.RandomRotation(aug_config['rotation']),
        transforms.ColorJitter(
            brightness=aug_config['brightness'],
            saturation=aug_config['saturation'],
        ),
        transforms.GaussianBlur(aug_config['blur']),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
def apply_augmentation(batch, augmentation):
    """Apply *augmentation* to every image in *batch* (mutates and returns it)."""
    transformed = []
    for image in batch['image']:
        transformed.append(augmentation(image))
    batch['image'] = transformed
    return batch
def apply_normalisation(batch, normalisation):
    """Run *normalisation* over each image in *batch* (mutates and returns it)."""
    batch['image'] = list(map(normalisation, batch['image']))
    return batch
"""
Creates and returns DataLoaders (train, val, test) for a given dataset.
Performs a 70/15/15 split
"""
def make_dataset_loaders(dataset, seed, batch_size, test_size, aug_config, workers=8):
# Define transformation pipelines for the dataset
normalisation = make_norm_pipeline()
augmentation = make_augment_pipeline(aug_config)
# 70/30 split creates train set
split_1 = dataset.train_test_split(test_size=test_size, seed=seed)
train_split = split_1['train']
remaining_split = split_1['test']
# 15/15 split on remaining data - validation and test sets
val_split = 0.5
split_2 = remaining_split.train_test_split(test_size=val_split, seed=seed)
val_split, test_split = split_2['train'], split_2['test']
# Put each split through pipelines
train_split.set_transform(lambda batch: apply_augmentation(batch, augmentation))
val_split.set_transform(lambda batch: apply_normalisation(batch, normalisation))
test_split.set_transform(lambda batch: apply_normalisation(batch, normalisation))
# Create dataloader for each
train_loader = DataLoader(
train_split,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=workers
)
val_loader = DataLoader(
val_split,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=workers
)
test_loader = DataLoader(
test_split,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=workers
)
class_names = dataset.features['label'].names
print(f"\nWorkers used in DataLoaders: {workers}\n")
dataset_loaders = {
"train": train_loader,
"val": val_loader,
"test": test_loader,
"classNames": class_names
}
return dataset_loaders
|