# smallGroupProject/dataPrep/data_preparation.py
# k23064919's picture
# Merge branch 'develop' of https://github.kcl.ac.uk/K23064919/smallGroupProject into develop
# fcf6bb8
# --- Standard Python Library ---
import os
import random
# --- Data Handling & Analysis ---
import numpy as np
import pandas as pd
from datasets import load_dataset
from helpers.create_dataset import make_subset
from helpers.transforms_loaders import make_dataset_loaders
# --- Visualization ---
import matplotlib.pyplot as plt
# import seaborn as sns
# --- PyTorch (Machine Learning) ---
import torch
# --- Experiment Tracking ---
from clearml import Task
# -------- Controllable parameters --------
# Every tunable value of the script lives here; the values are also sent to
# ClearML further down so each run records the configuration it used.

# Dataset parameters
SEED: int = 42  # seed handed to every RNG and to the loader factory
DATASET_LINK: str = "DScomp380/plant_village"  # identifier passed to make_subset
DATASET_SUBSET_RATIO: float = 0.25  # ratio passed to make_subset for the prototyping subset

# Augmentation parameters (forwarded to make_dataset_loaders via aug_config)
ROTATION: int = 30  # presumably the maximum rotation in degrees; confirm in helpers
BRIGHTNESS: float = 0.2
SATURATION: float = 0.2
BLUR: int = 3  # presumably a blur kernel size; confirm in helpers

# DataLoader parameters
BATCH_SIZE: int = 32
TEST_SIZE: float = 0.3  # presumably the held-out fraction; confirm in helpers
# Seed every random number generator used by the pipeline (Python, NumPy,
# PyTorch CPU) with the same SEED so experiments can be repeated.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Seed all CUDA devices as well, but only when a GPU is actually present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# ----- ClearML Setup -----
project_name = "Small Group Project"

# Register SEED as ClearML's default seed *before* the task is created.
# Task.set_random_seed only applies to tasks initialised afterwards, so
# calling it after Task.init (as this script previously did) left the task
# on ClearML's own default seed instead of SEED.
Task.set_random_seed(SEED)

# Create (or reuse) the ClearML task that tracks this data-preparation run.
task = Task.init(
    project_name=f'{project_name}/Data Preparation',
    task_name='Data Preparation',
    task_type=Task.TaskTypes.data_processing,
)

# Logger used below to report scalars, figures and text to ClearML.
clearml_logger = task.get_logger()
# -------- Track full configuration in ClearML --------
# Attach every controllable parameter to the task, grouped by the pipeline
# stage that consumes it (dataset, augmentation, dataloaders).
# NOTE(review): the object returned by task.connect is discarded, so the rest
# of the script keeps using the module-level constants - confirm this is the
# intended behaviour when the task is executed remotely.
task.connect(
    dict(
        seed=SEED,
        dataset=dict(
            link=DATASET_LINK,
            subset_ratio=DATASET_SUBSET_RATIO,
        ),
        augmentation=dict(
            rotation=ROTATION,
            brightness=BRIGHTNESS,
            saturation=SATURATION,
            blur=BLUR,
        ),
        dataloaders=dict(
            batch_size=BATCH_SIZE,
            test_size=TEST_SIZE,
        ),
    )
)
# ----- Load a subset from a given dataset & track with ClearML -----
# make_subset receives the dataset link, the subset ratio and the ClearML
# logger. Its four results are used later in this script as:
#   data_plants         - dataset handed to the final loaders
#   prototyping_dataset - dataset used for the EDA and the prototype loaders
#   features            - feature description; features['label'].names is read
#   clearml_dataset     - object whose .id is recorded as dataset metadata
(
    data_plants,
    prototyping_dataset,
    features,
    clearml_dataset,
) = make_subset(DATASET_LINK, DATASET_SUBSET_RATIO, clearml_logger)
# ---- Exploratory data analysis (EDA) ----
# Count how many samples each label has in the prototyping subset so that
# class bias is visible before any training happens.
labels_list = prototyping_dataset['label']
df_labels = pd.Series(labels_list)
label_count = df_labels.value_counts(sort=False)

# Summary statistics of the per-class sample counts.
min_count = label_count.min()
max_count = label_count.max()
mean_count = label_count.mean()
imbalance_ratio = max_count / min_count

# Log the statistics to ClearML. The mean is reported as well; it used to be
# computed and printed but never sent to the dashboard.
for series_name, series_value in (
    ("Min Class Count", min_count),
    ("Max Class Count", max_count),
    ("Mean Class Count", mean_count),
    ("Imbalance Ratio (Max/Min)", imbalance_ratio),
):
    clearml_logger.report_scalar(
        title="Exploratory data analysis (EDA)",
        series=series_name,
        value=series_value,
        iteration=1,
    )

print("--- Class imbalance analysis --- ")
print(f"Max labels in a class: {max_count}")
print(f"Min labels in a class: {min_count}")
print(f"Mean labels in a class: {mean_count}")
print(f"Imbalance ratio: {imbalance_ratio:.2f}")

# Mapping indices to class names.
class_names = features['label'].names
formatted_class_names = [" ".join(name.replace('_', ' ').split()) for name in class_names]

# value_counts(sort=False) is indexed by the label values that actually occur,
# in no guaranteed order, and omits classes missing from the subset. Assigning
# the full name list positionally would therefore mislabel the bars, or raise
# a length mismatch when a class is absent. Sort by label id and look each
# name up by its id instead.
# Assumes the labels are integer ids into `class_names` - TODO confirm.
label_count = label_count.sort_index()
label_count.index = [formatted_class_names[label_id] for label_id in label_count.index]

# Bar chart of the class distribution, reported to ClearML.
fig, ax = plt.subplots(figsize=(10, 6))
label_count.plot(kind='bar', color='skyblue', ax=ax)
ax.set_title("Class Distribution in Prototype Dataset")
ax.set_xlabel("Class")
ax.set_ylabel("Count")
fig.tight_layout()
clearml_logger.report_matplotlib_figure(
    title="EDA Class Distribution",
    series="Prototype Subset",
    figure=fig,
    iteration=1,
)
# ----------------------------------------------------------------------
# Script entry point: build the train/val/test loaders for the prototyping
# subset and for the full dataset, record the dataset id, then close the task.
if __name__ == "__main__":
    # ---------------- Dataset splits ----------------
    # Augmentation settings forwarded to the loader factory.
    aug_config = {
        'rotation': ROTATION,
        'brightness': BRIGHTNESS,
        'saturation': SATURATION,
        'blur': BLUR
    }

    # Loaders for the small prototyping subset.
    prototype_loaders = make_dataset_loaders(
        prototyping_dataset, SEED, BATCH_SIZE, TEST_SIZE, aug_config
    )
    print("\n--- Handoff Test Successful ---")
    print(f"Prototype Train loader batches: {len(prototype_loaders['train'])}")
    print(f"Prototype Validation loader batches: {len(prototype_loaders['val'])}")
    print(f"Prototype Test loader batches: {len(prototype_loaders['test'])}")
    clearml_logger.report_text(
        f"Prototype loaders created: "
        f"train={len(prototype_loaders['train'])}, "
        f"val={len(prototype_loaders['val'])}, "
        f"test={len(prototype_loaders['test'])}"
    )

    # Loaders for the full dataset.
    final_loaders = make_dataset_loaders(
        data_plants, SEED, BATCH_SIZE, TEST_SIZE, aug_config
    )
    print("\n--- Handoff Test Successful ---")
    print(f"Train loader batches: {len(final_loaders['train'])}")
    print(f"Validation loader batches: {len(final_loaders['val'])}")
    print(f"Test loader batches: {len(final_loaders['test'])}")

    # Record dataset info in ClearML
    task.connect_configuration(
        {"dataset_id": clearml_dataset.id},
        name="Dataset Metadata"
    )

    # Close the ClearML task. Task.close() already sets the status to
    # "Completed". The earlier task.mark_completed() call was removed because,
    # when invoked from the process that created the task, it terminates that
    # process immediately, so neither task.close() nor the final message below
    # was ever reached.
    task.close()
    print("\n--- Script Finished ---")