ra1425 committed on
Commit
afc3315
·
1 Parent(s): c0dc8ab

FEAT: Implemented the complete prototype data pipeline

Browse files
Files changed (1) hide show
  1. data_preparation.py +97 -11
data_preparation.py CHANGED
@@ -1,16 +1,25 @@
 
1
  import os
2
  import random
 
 
3
  import numpy as np
4
  import pandas as pd
 
5
 
6
- # Visualisation
7
- #import seaborn as sns
8
  import matplotlib.pyplot as plt
 
9
 
 
10
  import torch
 
 
 
 
11
  from clearml import Task, Logger
12
- from datasets import load_dataset
13
 
 
14
  SEED = 42
15
  random.seed(SEED)
16
  np.random.seed(SEED)
@@ -24,8 +33,10 @@ task = Task.init(project_name= 'smallGroupProject', task_name = 'data_prep')
24
  task.set_random_seed(SEED)
25
  clearml_logger = task.get_logger()
26
 
 
27
 
28
- # Loading dataset from HugginFace
 
29
  try:
30
  ds = load_dataset("DScomp380/plant_village")
31
  except Exception as e:
@@ -33,14 +44,18 @@ except Exception as e:
33
 
34
  data_plants = ds['train']
35
 
 
36
  # Verification
37
  print(f"\nLoaded object type: {type(data_plants)}")
 
38
 
39
  data_length = len(data_plants)
40
  print(f"\nLoaded object size: {data_length}")
 
41
 
42
  features = data_plants.features
43
  print(f"\nDataset features: {features}")
 
44
 
45
  # Verifying label count
46
  if 'label' in features and hasattr(features['label'], 'num_classes'):
@@ -48,13 +63,19 @@ if 'label' in features and hasattr(features['label'], 'num_classes'):
48
  print(f"Number of disease categories (labels): {label_count}")
49
  else:
50
  print("Couldnt determine the labels automatically")
 
 
51
 
52
  # Verifying single sample
53
  sample = data_plants[0]
54
  print(f"Sample image type: {type(sample['image'])}")
55
  print(f"Sample label: {sample['label']}")
 
 
 
56
 
57
- # -----------------------------------------------------------
 
58
  # Creating the prototyping dataset
59
  SUBSET_RATIO = 0.25 # 25% for prototyping
60
 
@@ -73,13 +94,11 @@ subset_indices = indices[:subset_size]
73
 
74
  prototyping_dataset = data_plants.select(subset_indices)
75
 
 
76
  #Verifying
77
  print(f"Prototyping dataset size: {len(prototyping_dataset)}")
78
 
79
- # -----------------------------------------------------------
80
- # Exploratory data analysis (EDA)
81
-
82
- #sns.set(color_codes = True)
83
 
84
  # Reformatting the label feature to understand bias
85
  labels_list = prototyping_dataset['label']
@@ -112,11 +131,12 @@ clearml_logger.report_scalar(
112
  iteration=1
113
  )
114
 
115
- print("Class imbalance analysis: ")
116
  print(f"Max labels in a class: {max_count}")
117
  print(f"Min labels in a class: {min_count}")
118
  print(f"Mean labels in a class: {mean_count}")
119
  print(f"Imbalance ratio: {max_count/min_count:.2f}")
 
120
 
121
  # Mapping indices to class names
122
  class_names = features['label'].names
@@ -139,4 +159,70 @@ clearml_logger.report_image(
139
  local_path=plot_file # The path to the file you just saved
140
  )
141
 
142
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Standard Python Library ---
2
  import os
3
  import random
4
+
5
+ # --- Data Handling & Analysis ---
6
  import numpy as np
7
  import pandas as pd
8
+ from datasets import load_dataset
9
 
10
+ # --- Visualization ---
 
11
  import matplotlib.pyplot as plt
12
+ # import seaborn as sns
13
 
14
+ # --- PyTorch (Machine Learning) ---
15
  import torch
16
+ from torchvision import transforms
17
+ from torch.utils.data import DataLoader
18
+
19
+ # --- Experiment Tracking ---
20
  from clearml import Task, Logger
 
21
 
22
+ # Setting up the SEED to be able to repeat experiments
23
  SEED = 42
24
  random.seed(SEED)
25
  np.random.seed(SEED)
 
33
  task.set_random_seed(SEED)
34
  clearml_logger = task.get_logger()
35
 
36
+ print("βœ… Checkpoint: Imports, SEED and ClearML are set")
37
 
38
+
39
+ # Loading dataset from HuggingFace and checking it
40
  try:
41
  ds = load_dataset("DScomp380/plant_village")
42
  except Exception as e:
 
44
 
45
  data_plants = ds['train']
46
 
47
+ print("--- Verification ---")
48
  # Verification
49
  print(f"\nLoaded object type: {type(data_plants)}")
50
+ print("\n --- \n")
51
 
52
  data_length = len(data_plants)
53
  print(f"\nLoaded object size: {data_length}")
54
+ print("\n --- \n")
55
 
56
  features = data_plants.features
57
  print(f"\nDataset features: {features}")
58
+ print("\n --- \n")
59
 
60
  # Verifying label count
61
  if 'label' in features and hasattr(features['label'], 'num_classes'):
 
63
  print(f"Number of disease categories (labels): {label_count}")
64
  else:
65
  print("Couldnt determine the labels automatically")
66
+ print("\n --- \n")
67
+
68
 
69
  # Verifying single sample
70
  sample = data_plants[0]
71
  print(f"Sample image type: {type(sample['image'])}")
72
  print(f"Sample label: {sample['label']}")
73
+ print("\n --- \n")
74
+
75
+ print("βœ… Checkpoint: Dataset is loaded and data is checked")
76
 
77
+
78
+ # --------------------------- Data selection --------------------------------
79
  # Creating the prototyping dataset
80
  SUBSET_RATIO = 0.25 # 25% for prototyping
81
 
 
94
 
95
  prototyping_dataset = data_plants.select(subset_indices)
96
 
97
+ print("βœ… Checkpoint: Prototyping dataset is created")
98
  #Verifying
99
  print(f"Prototyping dataset size: {len(prototyping_dataset)}")
100
 
101
+ # ---- Exploratory data analysis (EDA) ----
 
 
 
102
 
103
  # Reformatting the label feature to understand bias
104
  labels_list = prototyping_dataset['label']
 
131
  iteration=1
132
  )
133
 
134
+ print("--- Class imbalance analysis --- ")
135
  print(f"Max labels in a class: {max_count}")
136
  print(f"Min labels in a class: {min_count}")
137
  print(f"Mean labels in a class: {mean_count}")
138
  print(f"Imbalance ratio: {max_count/min_count:.2f}")
139
+ print("βœ… Checkpoint: Class distribution is calculated")
140
 
141
  # Mapping indices to class names
142
  class_names = features['label'].names
 
159
  local_path=plot_file # The path to the file you just saved
160
  )
161
 
162
+ # To see the plot uncomment but it'll pause the code
163
+ #plt.show()
164
+ print("βœ… Checkpoint: Plot with classes distributions is created and saved")
165
+ # --------------- Data Splits ------------
166
+
167
+ # Standard ImageNet mean and std
168
+ # These values are used to normalize the tensors
169
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
170
+ IMAGENET_STD = [0.229, 0.224, 0.225]
171
+
172
+ # Defining pipeline to ensure that images are consistently formatted (for Val/Test)
173
+ normalisation_pipeline = transforms.Compose([
174
+ # Convert PIL Image to a PyTorch Tensor
175
+ # This also scales pixel values from [0, 255] to [0.0, 1.0]
176
+ transforms.ToTensor(),
177
+
178
+ # Normalise the Tensor; Standardises pixel values
179
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
180
+ ])
181
+ print("βœ… Checkpoint: Transform pipeline created")
182
+
183
+ # Augmentation pipeline (to change some parameters of the pictures to create "new" ones)
184
+ augmentation_pipeline = transforms.Compose([
185
+ # Randomly changing some parameters of pictures to enrich dataset
186
+ transforms.RandomRotation(degrees=30),
187
+ transforms.ColorJitter(brightness=0.2, saturation=0.2),
188
+ transforms.GaussianBlur(kernel_size=3),
189
+
190
+ # Convert to Tensor and Normalise
191
+ transforms.ToTensor(),
192
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
193
+ ])
194
+
195
+ print("βœ… Checkpoint: Augmentation pipeline created")
196
+
197
+
198
+ # -- Split the prototype dataset --
199
+ # This returns a dictionary: {'train': 70%, 'test': 30%}
200
+ split_1_dict = prototyping_dataset.train_test_split(test_size=0.3, seed=SEED)
201
+
202
+ # Assign the 70% part to final train split
203
+ proto_train_split = split_1_dict['train']
204
+
205
+ # Assign the 30% part to a temporary var
206
+ proto_temp_split = split_1_dict['test']
207
+
208
+ # Split 30% into 2 15%
209
+ # This returns a dictionary: {'train': 50%, 'test': 50%}
210
+ split_2_dict = proto_temp_split.train_test_split(test_size=0.5, seed=SEED)
211
+
212
+ proto_val_split = split_2_dict['train']
213
+ proto_test_split = split_2_dict['test']
214
+
215
+ print("βœ… Checkpoint: Dataset splitted")
216
+
217
+ # -- Putting splits through pipelines --
218
+ proto_train_split.set_transform(augmentation_pipeline)
219
+ proto_val_split.set_transform(normalisation_pipeline)
220
+ proto_test_split.set_transform(normalisation_pipeline)
221
+
222
+ # -- Creating the prototype dataloaders --
223
+ BATCH_SIZE = 32
224
+ proto_train_loader = DataLoader(dataset = proto_train_split, batch_size = BATCH_SIZE, shuffle = True )
225
+ proto_val_loader = DataLoader(dataset = proto_val_split, batch_size = BATCH_SIZE, shuffle = False )
226
+ proto_test_loader = DataLoader(dataset = proto_test_split, batch_size = BATCH_SIZE, shuffle = False )
227
+
228
+ print("βœ… Checkpoint: DataLoaders are set")