ra1425 commited on
Commit
6290586
·
1 Parent(s): 4839f9d

FEAT: Add final data loaders for the whole dataset

Browse files
Files changed (1) hide show
  1. data_preparation.py +68 -33
data_preparation.py CHANGED
@@ -1,5 +1,4 @@
1
  # --- Standard Python Library ---
2
- import os
3
  import random
4
 
5
  # --- Data Handling & Analysis ---
@@ -17,7 +16,7 @@ from torchvision import transforms
17
  from torch.utils.data import DataLoader
18
 
19
  # --- Experiment Tracking ---
20
- from clearml import Task, Logger
21
 
22
 
23
  # Setting up the SEED to be able to repeat experiments
@@ -170,37 +169,8 @@ print("✅ Checkpoint: Plot with classes distributions is created and saved")
170
 
171
  # --------------- Data Splits ------------
172
  def get_prototype_loaders(batch_size=32):
173
-
174
- # Standard ImageNet mean and std
175
- # These values are used to normalize the tensors
176
- IMAGENET_MEAN = [0.485, 0.456, 0.406]
177
- IMAGENET_STD = [0.229, 0.224, 0.225]
178
-
179
- # Defining pipeline to ensure that images are consistently formatted (for Val/Test)
180
- normalisation_pipeline = transforms.Compose([
181
- # Convert PIL Image to a PyTorch Tensor
182
- # This also scales pixel values from [0, 255] to [0.0, 1.0]
183
- transforms.ToTensor(),
184
-
185
- # Normalise the Tensor; Standartises pixel values
186
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
187
- ])
188
- print("✅ Checkpoint: Transform pipeline created")
189
-
190
- # Augmentation pipeline (to change some parameters of the pictures to create "new" ones)
191
- augmentation_pipeline = transforms.Compose([
192
- # Randomly changing some parameters of pictures to enrich dataset
193
- transforms.RandomRotation(degrees=30),
194
- transforms.ColorJitter(brightness=0.2, saturation=0.2),
195
- transforms.GaussianBlur(kernel_size=3),
196
-
197
- # Convert to Tensor and Normalise
198
- transforms.ToTensor(),
199
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
200
- ])
201
-
202
- print("✅ Checkpoint: Augmentation pipeline created")
203
-
204
 
205
  # -- Split the prototype dataset --
206
  # This returns a dictionary: {'train': 70%, 'test': 30%}
@@ -235,6 +205,71 @@ def get_prototype_loaders(batch_size=32):
235
  return proto_train_loader, proto_val_loader, proto_test_loader
236
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  # ----------------------------------------------------------------------
239
  if __name__ == "__main__":
240
 
 
1
  # --- Standard Python Library ---
 
2
  import random
3
 
4
  # --- Data Handling & Analysis ---
 
16
  from torch.utils.data import DataLoader
17
 
18
  # --- Experiment Tracking ---
19
+ from clearml import Task, Dataset
20
 
21
 
22
  # Setting up the SEED to be able to repeat experiments
 
169
 
170
  # --------------- Data Splits ------------
171
  def get_prototype_loaders(batch_size=32):
172
+ # Calling function to define pipelines
173
+ normalisation_pipeline, augmentation_pipeline = get_transform_pipelines()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  # -- Split the prototype dataset --
176
  # This returns a dictionary: {'train': 70%, 'test': 30%}
 
205
  return proto_train_loader, proto_val_loader, proto_test_loader
206
 
207
 
208
+
209
+ def get_final_loaders(batch_size=32):
210
+ # Calling function to define pipelines
211
+ normalisation_pipeline, augmentation_pipeline = get_transform_pipelines()
212
+
213
+ # -- Split the prototype dataset --
214
+ # This returns a dictionary: {'train': 70%, 'test': 30%}
215
+ split_1_dict = data_plants.train_test_split(test_size=0.3, seed=SEED)
216
+
217
+ # Assign the 70% part to final train split
218
+ train_split = split_1_dict['train']
219
+
220
+ # Assign the 30% part to a temporary var
221
+ temp_split = split_1_dict['test']
222
+
223
+ # Split 30% into 2 15%
224
+ # This returns a dictionary: {'train': 50%, 'test': 50%}
225
+ split_2_dict = temp_split.train_test_split(test_size=0.5, seed=SEED)
226
+
227
+ val_split = split_2_dict['train']
228
+ test_split = split_2_dict['test']
229
+
230
+ print("✅ Checkpoint: Dataset splitted")
231
+
232
+ # -- Putting splits through pipelines --
233
+ train_split.set_transform(augmentation_pipeline)
234
+ val_split.set_transform(normalisation_pipeline)
235
+ test_split.set_transform(normalisation_pipeline)
236
+
237
+ # -- Creating the prototype dataloaders --
238
+ train_loader = DataLoader(dataset = train_split, batch_size = batch_size, shuffle = True )
239
+ val_loader = DataLoader(dataset = val_split, batch_size = batch_size, shuffle = False )
240
+ test_loader = DataLoader(dataset = test_split, batch_size = batch_size, shuffle = False )
241
+
242
+ print("✅ Checkpoint: DataLoaders are set")
243
+ return train_loader, val_loader, test_loader
244
+
245
+
246
+ def get_transform_pipelines():
247
+ """
248
+ Defines and returns the normalization and augmentation pipelines.
249
+ """
250
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
251
+ IMAGENET_STD = [0.229, 0.224, 0.225]
252
+
253
+ # Defining pipeline for Val/Test
254
+ normalisation_pipeline = transforms.Compose([
255
+ transforms.ToTensor(),
256
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
257
+ ])
258
+
259
+ # Augmentation pipeline for Train
260
+ augmentation_pipeline = transforms.Compose([
261
+ transforms.RandomRotation(degrees=30),
262
+ transforms.ColorJitter(brightness=0.2, saturation=0.2),
263
+ transforms.GaussianBlur(kernel_size=3),
264
+ transforms.ToTensor(),
265
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
266
+ ])
267
+
268
+ print("✅ Checkpoint: Transform pipelines created")
269
+
270
+ # Return both pipelines
271
+ return normalisation_pipeline, augmentation_pipeline
272
+
273
  # ----------------------------------------------------------------------
274
  if __name__ == "__main__":
275