Spaces:
Sleeping
Sleeping
ra1425
committed on
Commit
·
afc3315
1
Parent(s):
c0dc8ab
FEAT: Implemented the complete prototype data pipeline
Browse files- data_preparation.py +97 -11
data_preparation.py
CHANGED
|
@@ -1,16 +1,25 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import random
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
|
|
|
| 5 |
|
| 6 |
-
#
|
| 7 |
-
#import seaborn as sns
|
| 8 |
import matplotlib.pyplot as plt
|
|
|
|
| 9 |
|
|
|
|
| 10 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from clearml import Task, Logger
|
| 12 |
-
from datasets import load_dataset
|
| 13 |
|
|
|
|
| 14 |
SEED = 42
|
| 15 |
random.seed(SEED)
|
| 16 |
np.random.seed(SEED)
|
|
@@ -24,8 +33,10 @@ task = Task.init(project_name= 'smallGroupProject', task_name = 'data_prep')
|
|
| 24 |
task.set_random_seed(SEED)
|
| 25 |
clearml_logger = task.get_logger()
|
| 26 |
|
|
|
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
try:
|
| 30 |
ds = load_dataset("DScomp380/plant_village")
|
| 31 |
except Exception as e:
|
|
@@ -33,14 +44,18 @@ except Exception as e:
|
|
| 33 |
|
| 34 |
data_plants = ds['train']
|
| 35 |
|
|
|
|
| 36 |
# Verification
|
| 37 |
print(f"\nLoaded object type: {type(data_plants)}")
|
|
|
|
| 38 |
|
| 39 |
data_length = len(data_plants)
|
| 40 |
print(f"\nLoaded object size: {data_length}")
|
|
|
|
| 41 |
|
| 42 |
features = data_plants.features
|
| 43 |
print(f"\nDataset features: {features}")
|
|
|
|
| 44 |
|
| 45 |
# Verifying label count
|
| 46 |
if 'label' in features and hasattr(features['label'], 'num_classes'):
|
|
@@ -48,13 +63,19 @@ if 'label' in features and hasattr(features['label'], 'num_classes'):
|
|
| 48 |
print(f"Number of disease categories (labels): {label_count}")
|
| 49 |
else:
|
| 50 |
print("Couldnt determine the labels automatically")
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Verifying single sample
|
| 53 |
sample = data_plants[0]
|
| 54 |
print(f"Sample image type: {type(sample['image'])}")
|
| 55 |
print(f"Sample label: {sample['label']}")
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
|
|
|
| 58 |
# Creating the prototyping dataset
|
| 59 |
SUBSET_RATIO = 0.25 # 25% for prototyping
|
| 60 |
|
|
@@ -73,13 +94,11 @@ subset_indices = indices[:subset_size]
|
|
| 73 |
|
| 74 |
prototyping_dataset = data_plants.select(subset_indices)
|
| 75 |
|
|
|
|
| 76 |
#Verifying
|
| 77 |
print(f"Prototyping dataset size: {len(prototyping_dataset)}")
|
| 78 |
|
| 79 |
-
#
|
| 80 |
-
# Exploratory data analysis (EDA)
|
| 81 |
-
|
| 82 |
-
#sns.set(color_codes = True)
|
| 83 |
|
| 84 |
# Reformatting the label feature to understand bias
|
| 85 |
labels_list = prototyping_dataset['label']
|
|
@@ -112,11 +131,12 @@ clearml_logger.report_scalar(
|
|
| 112 |
iteration=1
|
| 113 |
)
|
| 114 |
|
| 115 |
-
print("Class imbalance analysis
|
| 116 |
print(f"Max labels in a class: {max_count}")
|
| 117 |
print(f"Min labels in a class: {min_count}")
|
| 118 |
print(f"Mean labels in a class: {mean_count}")
|
| 119 |
print(f"Imbalance ratio: {max_count/min_count:.2f}")
|
|
|
|
| 120 |
|
| 121 |
# Mapping indices to class names
|
| 122 |
class_names = features['label'].names
|
|
@@ -139,4 +159,70 @@ clearml_logger.report_image(
|
|
| 139 |
local_path=plot_file # The path to the file you just saved
|
| 140 |
)
|
| 141 |
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Standard Python Library ---
|
| 2 |
import os
|
| 3 |
import random
|
| 4 |
+
|
| 5 |
+
# --- Data Handling & Analysis ---
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
|
| 10 |
+
# --- Visualization ---
|
|
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
+
# import seaborn as sns
|
| 13 |
|
| 14 |
+
# --- PyTorch (Machine Learning) ---
|
| 15 |
import torch
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
|
| 19 |
+
# --- Experiment Tracking ---
|
| 20 |
from clearml import Task, Logger
|
|
|
|
| 21 |
|
| 22 |
+
# Setting up the SEED to be able to repeat experiments
|
| 23 |
SEED = 42
|
| 24 |
random.seed(SEED)
|
| 25 |
np.random.seed(SEED)
|
|
|
|
| 33 |
task.set_random_seed(SEED)
|
| 34 |
clearml_logger = task.get_logger()
|
| 35 |
|
| 36 |
+
print("✅ Checkpoint: Imports, SEED and ClearML are set")
|
| 37 |
|
| 38 |
+
|
| 39 |
+
# Loading the dataset from Hugging Face and checking it
|
| 40 |
try:
|
| 41 |
ds = load_dataset("DScomp380/plant_village")
|
| 42 |
except Exception as e:
|
|
|
|
| 44 |
|
| 45 |
data_plants = ds['train']
|
| 46 |
|
| 47 |
+
print("--- Verification ---")
|
| 48 |
# Verification
|
| 49 |
print(f"\nLoaded object type: {type(data_plants)}")
|
| 50 |
+
print("\n --- \n")
|
| 51 |
|
| 52 |
data_length = len(data_plants)
|
| 53 |
print(f"\nLoaded object size: {data_length}")
|
| 54 |
+
print("\n --- \n")
|
| 55 |
|
| 56 |
features = data_plants.features
|
| 57 |
print(f"\nDataset features: {features}")
|
| 58 |
+
print("\n --- \n")
|
| 59 |
|
| 60 |
# Verifying label count
|
| 61 |
if 'label' in features and hasattr(features['label'], 'num_classes'):
|
|
|
|
| 63 |
print(f"Number of disease categories (labels): {label_count}")
|
| 64 |
else:
|
| 65 |
print("Couldnt determine the labels automatically")
|
| 66 |
+
print("\n --- \n")
|
| 67 |
+
|
| 68 |
|
| 69 |
# Verifying single sample
|
| 70 |
sample = data_plants[0]
|
| 71 |
print(f"Sample image type: {type(sample['image'])}")
|
| 72 |
print(f"Sample label: {sample['label']}")
|
| 73 |
+
print("\n --- \n")
|
| 74 |
+
|
| 75 |
+
print("✅ Checkpoint: Dataset is loaded and data is checked")
|
| 76 |
|
| 77 |
+
|
| 78 |
+
# --------------------------- Data selection --------------------------------
|
| 79 |
# Creating the prototyping dataset
|
| 80 |
SUBSET_RATIO = 0.25 # 25% for prototyping
|
| 81 |
|
|
|
|
| 94 |
|
| 95 |
prototyping_dataset = data_plants.select(subset_indices)
|
| 96 |
|
| 97 |
+
print("✅ Checkpoint: Prototyping dataset is created")
|
| 98 |
#Verifying
|
| 99 |
print(f"Prototyping dataset size: {len(prototyping_dataset)}")
|
| 100 |
|
| 101 |
+
# ---- Exploratory data analysis (EDA) ----
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
# Reformatting the label feature to understand bias
|
| 104 |
labels_list = prototyping_dataset['label']
|
|
|
|
| 131 |
iteration=1
|
| 132 |
)
|
| 133 |
|
| 134 |
+
print("--- Class imbalance analysis --- ")
|
| 135 |
print(f"Max labels in a class: {max_count}")
|
| 136 |
print(f"Min labels in a class: {min_count}")
|
| 137 |
print(f"Mean labels in a class: {mean_count}")
|
| 138 |
print(f"Imbalance ratio: {max_count/min_count:.2f}")
|
| 139 |
+
print("✅ Checkpoint: Class distribution is calculated")
|
| 140 |
|
| 141 |
# Mapping indices to class names
|
| 142 |
class_names = features['label'].names
|
|
|
|
| 159 |
local_path=plot_file # The path to the file you just saved
|
| 160 |
)
|
| 161 |
|
| 162 |
+
# To see the plot, uncomment the line below, but it'll pause the code
|
| 163 |
+
#plt.show()
|
| 164 |
+
print("✅ Checkpoint: Plot with classes distributions is created and saved")
|
| 165 |
+
# --------------- Data Splits ------------
|
| 166 |
+
|
| 167 |
+
# Standard ImageNet mean and std
|
| 168 |
+
# These values are used to normalize the tensors
|
| 169 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 170 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 171 |
+
|
| 172 |
+
# Defining pipeline to ensure that images are consistently formatted (for Val/Test)
|
| 173 |
+
normalisation_pipeline = transforms.Compose([
|
| 174 |
+
# Convert PIL Image to a PyTorch Tensor
|
| 175 |
+
# This also scales pixel values from [0, 255] to [0.0, 1.0]
|
| 176 |
+
transforms.ToTensor(),
|
| 177 |
+
|
| 178 |
+
# Normalise the tensor; standardises pixel values
|
| 179 |
+
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
|
| 180 |
+
])
|
| 181 |
+
print("✅ Checkpoint: Transform pipeline created")
|
| 182 |
+
|
| 183 |
+
# Augmentation pipeline (to change some parameters of the pictures to create "new" ones)
|
| 184 |
+
augmentation_pipeline = transforms.Compose([
|
| 185 |
+
# Randomly changing some parameters of pictures to enrich dataset
|
| 186 |
+
transforms.RandomRotation(degrees=30),
|
| 187 |
+
transforms.ColorJitter(brightness=0.2, saturation=0.2),
|
| 188 |
+
transforms.GaussianBlur(kernel_size=3),
|
| 189 |
+
|
| 190 |
+
# Convert to Tensor and Normalise
|
| 191 |
+
transforms.ToTensor(),
|
| 192 |
+
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
|
| 193 |
+
])
|
| 194 |
+
|
| 195 |
+
print("✅ Checkpoint: Augmentation pipeline created")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# -- Split the prototype dataset --
|
| 199 |
+
# This returns a dictionary: {'train': 70%, 'test': 30%}
|
| 200 |
+
split_1_dict = prototyping_dataset.train_test_split(test_size=0.3, seed=SEED)
|
| 201 |
+
|
| 202 |
+
# Assign the 70% part to final train split
|
| 203 |
+
proto_train_split = split_1_dict['train']
|
| 204 |
+
|
| 205 |
+
# Assign the 30% part to a temporary var
|
| 206 |
+
proto_temp_split = split_1_dict['test']
|
| 207 |
+
|
| 208 |
+
# Split the 30% part into two 15% halves (validation and test)
|
| 209 |
+
# This returns a dictionary: {'train': 50%, 'test': 50%}
|
| 210 |
+
split_2_dict = proto_temp_split.train_test_split(test_size=0.5, seed=SEED)
|
| 211 |
+
|
| 212 |
+
proto_val_split = split_2_dict['train']
|
| 213 |
+
proto_test_split = split_2_dict['test']
|
| 214 |
+
|
| 215 |
+
print("✅ Checkpoint: Dataset splitted")
|
| 216 |
+
|
| 217 |
+
# -- Putting splits through pipelines --
|
| 218 |
+
proto_train_split.set_transform(augmentation_pipeline)
|
| 219 |
+
proto_val_split.set_transform(normalisation_pipeline)
|
| 220 |
+
proto_test_split.set_transform(normalisation_pipeline)
|
| 221 |
+
|
| 222 |
+
# -- Creating the prototype dataloaders --
|
| 223 |
+
BATCH_SIZE = 32
|
| 224 |
+
proto_train_loader = DataLoader(dataset = proto_train_split, batch_size = BATCH_SIZE, shuffle = True )
|
| 225 |
+
proto_val_loader = DataLoader(dataset = proto_val_split, batch_size = BATCH_SIZE, shuffle = False )
|
| 226 |
+
proto_test_loader = DataLoader(dataset = proto_test_split, batch_size = BATCH_SIZE, shuffle = False )
|
| 227 |
+
|
| 228 |
+
print("✅ Checkpoint: DataLoaders are set")
|