Spaces:
Sleeping
Sleeping
ra1425
committed on
Commit
·
6f31a0a
1
Parent(s):
42b46e3
REF: Rearrange code for readability
Browse files
- data_preparation.py +8 -4
data_preparation.py
CHANGED
|
@@ -19,6 +19,7 @@ from torch.utils.data import DataLoader
|
|
| 19 |
# --- Experiment Tracking ---
|
| 20 |
from clearml import Task, Logger
|
| 21 |
|
|
|
|
| 22 |
# Setting up the SEED to be able to repeat experiments
|
| 23 |
SEED = 42
|
| 24 |
random.seed(SEED)
|
|
@@ -84,6 +85,7 @@ task.connect_configuration(
|
|
| 84 |
{"subset_ratio": SUBSET_RATIO},
|
| 85 |
name="Data subsetting"
|
| 86 |
)
|
|
|
|
| 87 |
# Calculate amount of samples we use
|
| 88 |
subset_size = int(data_length * SUBSET_RATIO)
|
| 89 |
|
|
@@ -98,6 +100,7 @@ print("✅ Checkpoint: Prototyping dataset is created")
|
|
| 98 |
#Verifying
|
| 99 |
print(f"Prototyping dataset size: {len(prototyping_dataset)}")
|
| 100 |
|
|
|
|
| 101 |
# ---- Exploratory data analysis (EDA) ----
|
| 102 |
|
| 103 |
# Reformatting the label feature to understand bias
|
|
@@ -152,6 +155,7 @@ plt.title('Class distribution among chosen samples')
|
|
| 152 |
|
| 153 |
plot_file = 'class_distribution.png'
|
| 154 |
plt.savefig(plot_file)
|
|
|
|
| 155 |
clearml_logger.report_image(
|
| 156 |
title="EDA", # The title for the plot section in ClearML
|
| 157 |
series="Class Distribution", # The name of this specific plot
|
|
@@ -165,7 +169,7 @@ print("✅ Checkpoint: Plot with classes distributions is created and saved")
|
|
| 165 |
|
| 166 |
|
| 167 |
# --------------- Data Splits ------------
|
| 168 |
-
def get_prototype_loaders(batch_size=32)
|
| 169 |
|
| 170 |
# Standard ImageNet mean and std
|
| 171 |
# These values are used to normalize the tensors
|
|
@@ -223,9 +227,9 @@ def get_prototype_loaders(batch_size=32) :
|
|
| 223 |
proto_test_split.set_transform(normalisation_pipeline)
|
| 224 |
|
| 225 |
# -- Creating the prototype dataloaders --
|
| 226 |
-
proto_train_loader = DataLoader(dataset = proto_train_split, batch_size = batch_size, shuffle = True)
|
| 227 |
-
proto_val_loader = DataLoader(dataset = proto_val_split, batch_size = batch_size, shuffle = False)
|
| 228 |
-
proto_test_loader = DataLoader(dataset = proto_test_split, batch_size = batch_size, shuffle = False)
|
| 229 |
|
| 230 |
print("✅ Checkpoint: DataLoaders are set")
|
| 231 |
return proto_train_loader, proto_val_loader, proto_test_loader
|
|
|
|
| 19 |
# --- Experiment Tracking ---
|
| 20 |
from clearml import Task, Logger
|
| 21 |
|
| 22 |
+
|
| 23 |
# Setting up the SEED to be able to repeat experiments
|
| 24 |
SEED = 42
|
| 25 |
random.seed(SEED)
|
|
|
|
| 85 |
{"subset_ratio": SUBSET_RATIO},
|
| 86 |
name="Data subsetting"
|
| 87 |
)
|
| 88 |
+
|
| 89 |
# Calculate amount of samples we use
|
| 90 |
subset_size = int(data_length * SUBSET_RATIO)
|
| 91 |
|
|
|
|
| 100 |
#Verifying
|
| 101 |
print(f"Prototyping dataset size: {len(prototyping_dataset)}")
|
| 102 |
|
| 103 |
+
|
| 104 |
# ---- Exploratory data analysis (EDA) ----
|
| 105 |
|
| 106 |
# Reformatting the label feature to understand bias
|
|
|
|
| 155 |
|
| 156 |
plot_file = 'class_distribution.png'
|
| 157 |
plt.savefig(plot_file)
|
| 158 |
+
|
| 159 |
clearml_logger.report_image(
|
| 160 |
title="EDA", # The title for the plot section in ClearML
|
| 161 |
series="Class Distribution", # The name of this specific plot
|
|
|
|
| 169 |
|
| 170 |
|
| 171 |
# --------------- Data Splits ------------
|
| 172 |
+
def get_prototype_loaders(batch_size=32):
|
| 173 |
|
| 174 |
# Standard ImageNet mean and std
|
| 175 |
# These values are used to normalize the tensors
|
|
|
|
| 227 |
proto_test_split.set_transform(normalisation_pipeline)
|
| 228 |
|
| 229 |
# -- Creating the prototype dataloaders --
|
| 230 |
+
proto_train_loader = DataLoader(dataset = proto_train_split, batch_size = batch_size, shuffle = True )
|
| 231 |
+
proto_val_loader = DataLoader(dataset = proto_val_split, batch_size = batch_size, shuffle = False )
|
| 232 |
+
proto_test_loader = DataLoader(dataset = proto_test_split, batch_size = batch_size, shuffle = False )
|
| 233 |
|
| 234 |
print("✅ Checkpoint: DataLoaders are set")
|
| 235 |
return proto_train_loader, proto_val_loader, proto_test_loader
|