File size: 5,056 Bytes
afc3315
9af0f61
71353f6
afc3315
 
71353f6
c0dc8ab
afc3315
2ace27a
04cb886
c0dc8ab
afc3315
c0dc8ab
afc3315
c0dc8ab
afc3315
71353f6
afc3315
 
2ace27a
71353f6
6f31a0a
2ace27a
 
71353f6
2ace27a
04cb886
 
2ace27a
 
 
 
 
 
 
 
 
 
 
71353f6
 
 
 
 
 
 
04cb886
3f67469
9dbc9de
3f67469
9dbc9de
 
 
71353f6
c0dc8ab
 
2fd4542
 
9dbc9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ace27a
685281d
 
2ace27a
 
c0dc8ab
afc3315
c0dc8ab
 
685281d
c0dc8ab
2ace27a
c0dc8ab
 
 
 
 
89c8841
c0dc8ab
 
 
 
 
 
 
89c8841
c0dc8ab
 
 
 
 
 
 
89c8841
c0dc8ab
 
 
 
afc3315
c0dc8ab
 
 
 
 
 
 
 
 
 
6b1327e
 
685281d
6b1327e
 
 
 
9dbc9de
6b1327e
685281d
9dbc9de
6b1327e
 
 
 
42b46e3
 
 
2fd4542
2ace27a
 
 
 
 
 
 
04cb886
685281d
04cb886
 
42b46e3
04cb886
 
 
 
9dbc9de
 
 
 
 
 
 
04cb886
2ace27a
04cb886
89c8841
 
04cb886
 
 
6b1327e
685281d
 
 
 
 
 
9dbc9de
685281d
42b46e3
 
89c8841
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# --- Standard Python Library ---
import os
import random

# --- Data Handling & Analysis ---
import numpy as np
import pandas as pd
from datasets import load_dataset
from helpers.create_dataset import make_subset
from helpers.transforms_loaders import make_dataset_loaders

# --- Visualization ---
import matplotlib.pyplot as plt
# import seaborn as sns

# --- PyTorch (Machine Learning) ---
import torch

# --- Experiment Tracking ---
from clearml import Task


# -------- Controllable parameters --------
# Dataset parameters
SEED = 42
DATASET_LINK = "DScomp380/plant_village"
DATASET_SUBSET_RATIO = 0.25

# Augmentation parameters
ROTATION = 30
BRIGHTNESS = 0.2
SATURATION = 0.2
BLUR = 3

# DataLoader parameters
BATCH_SIZE = 32
TEST_SIZE = 0.3


def _seed_everything(seed):
    """Seed every RNG used in this pipeline so experiments are repeatable."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


_seed_everything(SEED)


# ----- ClearML Setup -----
# Register this script as a data-processing task under the shared project
# so every run shows up in the ClearML UI.
project_name = "Small Group Project"
_task_kwargs = {
    "project_name": f'{project_name}/Data Preparation',
    "task_name": 'Data Preparation',
    "task_type": Task.TaskTypes.data_processing,
}
task = Task.init(**_task_kwargs)
task.set_random_seed(SEED)
clearml_logger = task.get_logger()


# -------- Track full configuration in ClearML --------
# Collect every tunable into one nested dict so the experiment can be
# reproduced from the ClearML UI alone.
run_config = {
    "seed": SEED,
    "dataset": {
        "link": DATASET_LINK,
        "subset_ratio": DATASET_SUBSET_RATIO,
    },
    "augmentation": {
        "rotation": ROTATION,
        "brightness": BRIGHTNESS,
        "saturation": SATURATION,
        "blur": BLUR,
    },
    "dataloaders": {
        "batch_size": BATCH_SIZE,
        "test_size": TEST_SIZE,
    },
}
task.connect(run_config)

# ----- Load a subset from a given dataset & track with ClearML -----
# make_subset appears to return: the full dataset, a smaller prototyping
# subset (DATASET_SUBSET_RATIO of it), the label feature metadata, and a
# ClearML dataset handle — presumably downloading via `datasets`; confirm
# against helpers/create_dataset.py.
data_plants, prototyping_dataset, features, clearml_dataset = make_subset(
    DATASET_LINK, DATASET_SUBSET_RATIO, clearml_logger
)


# ---- Exploratory data analysis (EDA) ----

# Reformatting the label feature to understand bias
labels_list = prototyping_dataset['label']
df_labels = pd.Series(labels_list)
# value_counts(sort=False) keeps label order stable across runs.
label_count = df_labels.value_counts(sort=False)


def _log_eda_scalar(series, value):
    """Report one EDA summary statistic to the shared ClearML scalar plot."""
    clearml_logger.report_scalar(
        title="Exploratory data analysis (EDA)",
        series=series,
        value=value,
        iteration=1
    )


# Checking the amount of samples in each class and logging it to clearML
min_count = label_count.min()
_log_eda_scalar("Min Class Count", min_count)

max_count = label_count.max()
_log_eda_scalar("Max Class Count", max_count)

mean_count = label_count.mean()
# BUGFIX: mean_count was computed but never logged — the scalar reported
# immediately after it was the max/min ratio, so the mean silently never
# reached ClearML. Report both explicitly.
_log_eda_scalar("Mean Class Count", mean_count)
# min_count cannot be 0: value_counts only lists labels that occur.
_log_eda_scalar("Imbalance Ratio (Max/Min)", max_count / min_count)

print("--- Class imbalance analysis --- ")
print(f"Max labels in a class: {max_count}")
print(f"Min labels in a class: {min_count}")
print(f"Mean labels in a class: {mean_count}")
print(f"Imbalance ratio: {max_count/min_count:.2f}")

# Mapping indices to human-readable class names for the x-axis.
class_names = features['label'].names
# Underscores -> spaces, then collapse any repeated whitespace.
formatted_class_names = [" ".join(name.replace('_', ' ').split()) for name in class_names]
label_count.index = formatted_class_names

# Plot the class distribution and ship the figure to ClearML.
fig = plt.figure(figsize=(10, 6))
label_count.plot(kind='bar', color='skyblue')
plt.title("Class Distribution in Prototype Dataset")
plt.xlabel("Class")
plt.ylabel("Count")
plt.tight_layout()

clearml_logger.report_matplotlib_figure(
    title="EDA Class Distribution",
    series="Prototype Subset",
    # BUGFIX: pass the figure we created instead of plt.gcf(), which is
    # fragile if any other pyplot call changes the current figure.
    figure=fig,
    iteration=1
)
# BUGFIX: close the figure once reported so it is not left open in
# pyplot's global state (memory leak / accidental reuse downstream).
plt.close(fig)


# ----------------------------------------------------------------------
def _print_loader_summary(loaders, prefix=""):
    """Print the number of batches in each split of a loaders dict.

    Args:
        loaders: mapping with 'train', 'val' and 'test' DataLoader values.
        prefix: label prepended to each split line (e.g. "Prototype ").
    """
    print("\n--- Handoff Test Successful ---")
    print(f"{prefix}Train loader batches: {len(loaders['train'])}")
    print(f"{prefix}Validation loader batches: {len(loaders['val'])}")
    print(f"{prefix}Test loader batches: {len(loaders['test'])}")


if __name__ == "__main__":

    # ---------------- Dataset splits ----------------
    aug_config = {
        'rotation': ROTATION,
        'brightness': BRIGHTNESS,
        'saturation': SATURATION,
        'blur': BLUR
    }

    # Smoke-test the loader pipeline on the small prototyping subset first.
    prototype_loaders = make_dataset_loaders(
        prototyping_dataset, SEED, BATCH_SIZE, TEST_SIZE, aug_config
    )
    _print_loader_summary(prototype_loaders, prefix="Prototype ")

    clearml_logger.report_text(
        f"Prototype loaders created: "
        f"train={len(prototype_loaders['train'])}, "
        f"val={len(prototype_loaders['val'])}, "
        f"test={len(prototype_loaders['test'])}"
    )

    # Then build loaders over the full dataset.
    final_loaders = make_dataset_loaders(
        data_plants, SEED, BATCH_SIZE, TEST_SIZE, aug_config
    )
    _print_loader_summary(final_loaders)

    # Record dataset info in ClearML
    task.connect_configuration(
        {"dataset_id": clearml_dataset.id},
        name="Dataset Metadata"
    )
    # NOTE(review): mark_completed() immediately followed by close() —
    # some ClearML versions terminate the process inside mark_completed();
    # confirm the final print below is actually reached.
    task.mark_completed()

    # Close the ClearML task
    task.close()
    print("\n--- Script Finished ---")