File size: 3,713 Bytes
04cb886
 
 
 
 
 
 
 
2ace27a
 
 
83d4d7f
2ace27a
 
04cb886
 
 
83d4d7f
04cb886
 
 
 
 
 
 
2ace27a
 
 
 
 
 
 
 
 
 
04cb886
 
83d4d7f
04cb886
2ace27a
9dbc9de
2ace27a
04cb886
 
 
 
2ace27a
04cb886
 
78fbc90
 
 
 
 
 
 
 
 
04cb886
 
 
 
78fbc90
04cb886
 
2ace27a
 
04cb886
 
 
 
 
 
 
2ace27a
04cb886
 
 
 
78fbc90
 
 
04cb886
 
78fbc90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed657fc
78fbc90
 
04cb886
 
 
 
ed657fc
 
04cb886
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
A collection of data transformation and dataset loading functions.
"""

from functools import partial

from torch.utils.data import DataLoader
from torchvision import transforms


# Standard ImageNet mean and std - Used to normalize the tensors
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (256, 256)
# Defines and returns the normalization pipeline.
def make_norm_pipeline():
    """Build the non-random preprocessing pipeline (used for val/test images).

    Returns:
        A ``transforms.Compose`` that resizes to ``IMAGE_SIZE``, converts the
        image to a tensor and normalises it with ``IMAGENET_MEAN`` and
        ``IMAGENET_STD``.
    """
    steps = [
        # Bring every image to the same spatial size
        transforms.Resize(IMAGE_SIZE),
        # PIL Image -> PyTorch Tensor, pixel values scaled from [0, 255] to [0.0, 1.0]
        transforms.ToTensor(),
        # Standardise each channel with the ImageNet statistics
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)

# Defines and returns the augmentation (rotation, brightness, saturation, blur) pipeline.
def make_augment_pipeline(aug_config):
    """Build the training pipeline: resize, augment, tensorise, normalise.

    Args:
        aug_config: Mapping with the keys ``'rotation'``, ``'brightness'``,
            ``'saturation'`` and ``'blur'``. The values are handed unchanged
            to ``RandomRotation``, ``ColorJitter`` and ``GaussianBlur``.

    Returns:
        A ``transforms.Compose`` applying the augmentations followed by the
        same tensor conversion and normalisation as the evaluation pipeline.

    Raises:
        KeyError: If any of the four keys is missing from ``aug_config``.
    """
    # Read every setting up front so a missing key fails before any
    # transform object is constructed.
    angle, bright, sat, kernel = (
        aug_config[key]
        for key in ('rotation', 'brightness', 'saturation', 'blur')
    )

    steps = [
        transforms.Resize(IMAGE_SIZE),
        # Perturb geometry and colour to create "new" images for training
        transforms.RandomRotation(angle),
        transforms.ColorJitter(brightness=bright, saturation=sat),
        transforms.GaussianBlur(kernel),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)


def apply_augmentation(batch, augmentation):
    """Apply ``augmentation`` to every item of ``batch['image']``.

    The ``'image'`` entry is replaced with a new list of results; the same
    ``batch`` object is returned, with its other keys untouched.
    """
    augmented = []
    for image in batch['image']:
        augmented.append(augmentation(image))
    batch['image'] = augmented
    return batch

def apply_normalisation(batch, normalisation):
    """Apply ``normalisation`` to every item of ``batch['image']``.

    The ``'image'`` entry is replaced with a new list of results; the same
    ``batch`` object is returned, with its other keys untouched.
    """
    batch['image'] = list(map(normalisation, batch['image']))
    return batch
    

"""
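Creates and returns DataLoaders (train, val, test) for a given dataset.
Holds out `test_size` of the data, then halves that hold-out into validation
and test sets (a 70/15/15 split when test_size is 0.3).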
"""
def make_dataset_loaders(dataset, seed, batch_size, test_size, aug_config, workers=8):
    """Split ``dataset`` into train/val/test and wrap each split in a DataLoader.

    ``test_size`` of the data is held out first; that hold-out is then halved
    into validation and test sets (e.g. ``test_size=0.3`` gives 70/15/15).
    The training split gets the augmentation pipeline, the other two get the
    plain normalisation pipeline.

    Args:
        dataset: Object providing ``train_test_split``, ``set_transform`` on
            its splits, and ``features['label'].names`` (presumably a Hugging
            Face ``datasets.Dataset`` with an ``'image'`` column -- confirm
            against the caller).
        seed: Seed passed to both ``train_test_split`` calls.
        batch_size: Batch size for all three loaders.
        test_size: Share of ``dataset`` held out for validation + test.
        aug_config: Augmentation settings forwarded to
            ``make_augment_pipeline``.
        workers: ``num_workers`` for every DataLoader.

    Returns:
        Dict with the loaders under ``"train"``, ``"val"`` and ``"test"`` and
        the label names under ``"classNames"``.
    """
    # Define transformation pipelines for the dataset
    normalisation = make_norm_pipeline()
    augmentation = make_augment_pipeline(aug_config)

    # First split: training set vs. a hold-out of size `test_size`
    first_split = dataset.train_test_split(test_size=test_size, seed=seed)
    train_split = first_split['train']
    holdout_split = first_split['test']

    # Second split: halve the hold-out into validation and test sets
    holdout_test_fraction = 0.5
    second_split = holdout_split.train_test_split(
        test_size=holdout_test_fraction, seed=seed
    )
    val_split, test_split = second_split['train'], second_split['test']

    # Use partials of module-level functions rather than local lambdas: the
    # transform is stored on the dataset, which is pickled into each worker
    # process under the "spawn" start method, and local lambdas cannot be
    # pickled.
    train_transform = partial(apply_augmentation, augmentation=augmentation)
    eval_transform = partial(apply_normalisation, normalisation=normalisation)
    train_split.set_transform(train_transform)
    val_split.set_transform(eval_transform)
    test_split.set_transform(eval_transform)

    # Create a dataloader for each split; only the training data is shuffled
    common_kwargs = {
        "batch_size": batch_size,
        "pin_memory": True,
        "num_workers": workers,
    }
    train_loader = DataLoader(train_split, shuffle=True, **common_kwargs)
    val_loader = DataLoader(val_split, shuffle=False, **common_kwargs)
    test_loader = DataLoader(test_split, shuffle=False, **common_kwargs)

    class_names = dataset.features['label'].names

    print(f"\nWorkers used in DataLoaders: {workers}\n")

    return {
        "train": train_loader,
        "val": val_loader,
        "test": test_loader,
        "classNames": class_names,
    }