File size: 4,801 Bytes
5762bbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from utils.config import load_config

# Runtime settings come from the project config mapping returned by load_config().
# batch_size / num_workers / normalize_* are required keys; execute_remotely
# is optional and treated as False when absent.
config = load_config()
batch_size, num_workers = config["batch_size"], config["num_workers"]
mean_nm, std_nm = config["normalize_mean"], config["normalize_std"]
execute_remotely = config.get("execute_remotely", False)

# Prefer the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the raw dataset: from a ClearML-managed local copy when executing
# remotely, otherwise from the Hugging Face Hub (cached under ./data_cache).
# Branch on the defaulted `execute_remotely` flag read from config above;
# indexing config["execute_remotely"] directly raised KeyError whenever the
# optional key was missing from the config.
if execute_remotely:
    from clearml import Dataset as ClearMLDataset

    # Overridable via config; defaults to the previously hard-coded dataset ID.
    _clearml_dataset_id = config.get(
        "clearml_dataset_id", "0c3de7af2d98482dacf41633a0587845"
    )
    clearml_dataset = ClearMLDataset.get(dataset_id=_clearml_dataset_id)
    dataset_path = clearml_dataset.get_local_copy()
    dataset = load_dataset(dataset_path)
else:
    dataset = load_dataset("DScomp380/plant_village", cache_dir="./data_cache")

# Split the source "train" split 70/30. The fixed seed keeps the split
# reproducible across runs.
splits = dataset["train"].train_test_split(test_size=0.30, seed=42)
train_split = splits["train"]  # 70% -> training set
remaining = splits["test"]  # 30% held out for validation + test

# Halve the held-out 30%: 15% of the data for validation, 15% for test.
val_test = remaining.train_test_split(test_size=0.5, seed=42)
val_split = val_test["train"]  # validation set
test_split = val_test["test"]  # test set

# Deterministic preprocessing applied to every split:
#   1. resize to a fixed 224x224,
#   2. convert the image to a float tensor (ToTensor),
#   3. normalize with the per-channel statistics from config.
preprocess_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_nm, std=std_nm),
    ]
)

def preprocess_batch(batch):
    """Add a "pixel_values" column holding the preprocessed "image" entries.

    Args:
        batch: A batched dataset dict with an "image" list. It is modified in place.

    Returns:
        The same dict, with "pixel_values" set to one tensor per input image.
    """
    images = batch["image"]
    batch["pixel_values"] = list(map(preprocess_transform, images))
    return batch

def _preprocess_split(split, name):
    """Eagerly preprocess one split and expose it as torch tensors.

    Runs ``preprocess_batch`` over ``split`` in batches of 100 and drops the raw
    "image" column. The result is written to
    ./data_cache/<name>_preprocessed.arrow, and the formatted dataset exposes
    only "pixel_values" and "label" as torch tensors.
    """
    processed = split.map(
        preprocess_batch,
        batched=True,
        batch_size=100,
        remove_columns=["image"],
        cache_file_name=f"./data_cache/{name}_preprocessed.arrow",
    )
    processed.set_format(type="torch", columns=["pixel_values", "label"])
    return processed


if execute_remotely:
    # Remote runs preprocess lazily on access instead of writing Arrow caches.
    # The lazy transform is exactly preprocess_batch, so reuse it rather than
    # keeping a duplicate definition.
    train_transform_batch = preprocess_batch  # name kept for backward compatibility
    train_split = train_split.with_transform(preprocess_batch)
    val_split = val_split.with_transform(preprocess_batch)
    test_split = test_split.with_transform(preprocess_batch)
else:
    # NOTE(review): when an explicit cache_file_name already exists, datasets
    # appears to reload it without re-checking the transform fingerprint.
    # After changing the preprocessing or the normalize_* config values, these
    # files would then need deleting by hand. Confirm against the installed
    # datasets version.
    train_split = _preprocess_split(train_split, "train")
    val_split = _preprocess_split(val_split, "val")
    test_split = _preprocess_split(test_split, "test")

# Training-time augmentations. train_collate_fn applies them to tensors that
# preprocess_transform has already normalized (ToTensor -> Normalize).
#
# Tensor ColorJitter clamps float images to [0, 1], and its hue adjustment
# assumes that range. Applied directly to normalized tensors, it zeroed every
# below-mean pixel. So the normalization is undone around the jitter and then
# re-applied. None of the added Normalize steps draw random numbers, so the
# random-parameter sequence is unchanged.
_aug_mean = torch.as_tensor(mean_nm, dtype=torch.float32)
_aug_std = torch.as_tensor(std_nm, dtype=torch.float32)

train_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    # Still applied in normalized space, as before. Rotated-in corners are filled
    # with RandomRotation's default fill (0), which in normalized space is the mean.
    transforms.RandomRotation(degrees=15),
    # Inverse normalization: (x - (-m/s)) / (1/s) == x*s + m. This maps tensors
    # back to the ToTensor value range that ColorJitter expects.
    transforms.Normalize(
        mean=(-_aug_mean / _aug_std).tolist(),
        std=(1.0 / _aug_std).tolist(),
    ),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # Re-apply the same normalization as preprocess_transform.
    transforms.Normalize(mean=mean_nm, std=std_nm),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    ], p=0.3),
])


def train_collate_fn(batch):
    """Collate training samples into one batch, augmenting each image.

    Every sample's "pixel_values" tensor goes through ``train_augment``
    independently, in batch order. The results are stacked along a new leading
    dimension.

    Args:
        batch: Sequence of sample dicts with "pixel_values" and "label" keys.

    Returns:
        Dict with "pixel_values" (stacked augmented images) and "labels"
        (labels gathered into one tensor by ``torch.tensor``).
    """
    images = []
    targets = []
    for sample in batch:
        images.append(train_augment(sample["pixel_values"]))
        targets.append(sample["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

def val_collate_fn(batch):
    """Collate evaluation samples into one batch, without augmentation.

    Args:
        batch: Sequence of sample dicts with "pixel_values" and "label" keys.

    Returns:
        Dict with "pixel_values" (images stacked along a new leading dimension)
        and "labels" (labels gathered into one tensor by ``torch.tensor``).
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])
    stacked = torch.stack(images)
    return {"pixel_values": stacked, "labels": torch.tensor(targets)}

# Pinned host memory only speeds up host-to-GPU copies. Without CUDA,
# DataLoader warns and ignores it, so enable it only when the selected device
# is a CUDA device.
_pin_memory = device.type == "cuda"


def _make_loader(dataset, *, shuffle, collate_fn):
    """Build a DataLoader with the shared batch-size/worker settings from config."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=_pin_memory,
        # Persistent workers are only valid when worker processes exist.
        persistent_workers=num_workers > 0,
        collate_fn=collate_fn,
    )


# Create DataLoaders for the train, val, and test sets. Only training is
# shuffled and augmented.
train_loader = _make_loader(train_split, shuffle=True, collate_fn=train_collate_fn)
val_loader = _make_loader(val_split, shuffle=False, collate_fn=val_collate_fn)
test_loader = _make_loader(test_split, shuffle=False, collate_fn=val_collate_fn)

if __name__ == "__main__":
    # Quick sanity report: the selected device and the size of each split.
    report = [
        ("Device", device),
        ("Train samples", len(train_split)),
        ("Val samples", len(val_split)),
        ("Test samples", len(test_split)),
        ("Batches per epoch", len(train_loader)),
    ]
    for label, value in report:
        print(f"{label}: {value}")