# Author: Yusuf
# Commit message: fix dataloader worker number
# Commit: 84cfdfc
import os
import numpy as np
from clearml import Task, Dataset
from datasets import load_dataset
from dataPrep.helpers.transforms_loaders import make_dataset_loaders
def _find_latest_data_prep_task(project_name):
    """Return a handle to the latest data-processing task in the Data Prep sub-project.

    Searches ``<project_name>/Data Preparation`` (archived tasks excluded),
    ordered by most recent update, and keeps only tasks whose type is
    ``data_processing`` and whose ``completed`` field is set.

    Raises:
        RuntimeError: if the sub-project has no tasks or none match.
    """
    data_prep_project = f'{project_name}/Data Preparation'
    # Newest first, so the first matching task is the latest one.
    all_tasks = Task.get_tasks(
        project_name=data_prep_project,
        allow_archived=False,
        task_filter={'order_by': ["-last_update"]},
    )
    if not all_tasks:
        # Report the sub-project that was actually searched.
        raise RuntimeError(f"No tasks found in project '{data_prep_project}'")

    dp_tasks = [
        t for t in all_tasks
        if t.task_type == Task.TaskTypes.data_processing
        and t.completed is not None
    ]
    if not dp_tasks:
        raise RuntimeError("No 'Data Preparation' tasks found in this project!")

    # Re-fetch the chosen task by id to obtain a full Task handle.
    return Task.get_task(task_id=dp_tasks[0].id)


def _load_subset_indices(data_prep):
    """Download the ``subset_indices`` artifact of *data_prep* and load it with NumPy.

    Raises:
        RuntimeError: if the task did not upload a ``subset_indices`` artifact.
    """
    artifacts = data_prep.artifacts
    if "subset_indices" not in artifacts:
        raise RuntimeError("Data Prep task did not upload 'subset_indices' artifact!")
    return np.load(artifacts["subset_indices"].get_local_copy())


def extract_latest_data_task(project_name: str = "Small Group Project", num_workers: int = 0):
    """Rebuild data loaders from the latest ClearML Data Preparation task.

    Finds the most recently updated, completed data-processing task in
    ``<project_name>/Data Preparation`` and reconstructs:

    - DataLoaders for both the full dataset and the subset selected by the
      task's ``subset_indices`` artifact;
    - the augmentation settings and other parameters that task used.

    Args:
        project_name: Parent ClearML project. Its ``/Data Preparation``
            sub-project is the one searched.
        num_workers: Number of DataLoader workers, forwarded to
            ``get_data_loaders``.

    Returns:
        ``(subset_loaders, full_loaders, data_prep_metadata)``. The loaders are
        whatever ``get_data_loaders`` returns. ``data_prep_metadata`` records
        the task id, dataset link, subset ratio, augmentation, batch size,
        seed and test size that were used.

    Raises:
        RuntimeError: if no suitable task exists, the ``subset_indices``
            artifact is missing, the dataset cannot be loaded, or it has no
            ``'train'`` split.
        KeyError: if an expected ``General/...`` parameter is missing.
    """
    data_prep = _find_latest_data_prep_task(project_name)
    subset_indices = _load_subset_indices(data_prep)

    # --------- Parameters the Data Prep task ran with ---------
    data_params = data_prep.get_parameters()
    subset_ratio = float(data_params['General/dataset/subset_ratio'])
    dataset_link = data_params['General/dataset/link']
    seed = int(data_params['General/seed'])
    batch_size = int(data_params['General/dataloaders/batch_size'])
    test_size = float(data_params['General/dataloaders/test_size'])
    # Parsed up front so a missing/invalid augmentation key fails before the
    # dataset download. It is replaced below by the config that
    # get_data_loaders reports it actually used.
    aug_config = {
        'rotation': float(data_params['General/augmentation/rotation']),
        'brightness': float(data_params['General/augmentation/brightness']),
        'saturation': float(data_params['General/augmentation/saturation']),
        'blur': float(data_params['General/augmentation/blur']),
    }

    # --------- Load full dataset ---------
    try:
        ds = load_dataset(dataset_link)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Error loading the dataset: {e}") from e
    if 'train' not in ds:
        raise RuntimeError(f"Dataset '{dataset_link}' has no 'train' split")
    full_dataset = ds['train']

    # Selecting the saved indices reproduces the same subset as Data Prep.
    subset_dataset = full_dataset.select(subset_indices)

    subset_loaders, full_loaders, aug_config = get_data_loaders(
        data_params, subset_dataset, full_dataset, num_workers=num_workers
    )

    data_prep_metadata = {
        "data_prep_task_id": data_prep.id,
        "dataset_link": dataset_link,
        "subset_ratio_used": subset_ratio,
        "augmentation_used": aug_config,
        "batch_size_used": batch_size,
        "seed_used": seed,
        "test_size_used": test_size,
    }
    return subset_loaders, full_loaders, data_prep_metadata
def get_data_loaders(data_params, subset_dataset, full_dataset, num_workers):
    """Create train/val/test DataLoaders for a subset and for the full dataset.

    The seed, batch size, test split size and augmentation settings are read
    from ``data_params``, the flattened ClearML parameter dict of a Data Prep
    task with ``General/...`` keys. They are passed, along with
    ``num_workers``, to ``make_dataset_loaders`` for each dataset. The batch
    count of every split is printed as a quick hand-off sanity check.

    Returns:
        ``(subset_loaders, full_loaders, aug_config)``: the loader collections
        produced by ``make_dataset_loaders`` for the subset and for the full
        dataset (indexed by ``'train'``, ``'val'`` and ``'test'`` here), and
        the augmentation dict that was passed to them.
    """
    seed = int(data_params['General/seed'])
    batch_size = int(data_params['General/dataloaders/batch_size'])
    test_size = float(data_params['General/dataloaders/test_size'])
    aug_config = {
        aug_name: float(data_params[f'General/augmentation/{aug_name}'])
        for aug_name in ('rotation', 'brightness', 'saturation', 'blur')
    }

    split_labels = (('train', 'Train'), ('val', 'Validation'), ('test', 'Test'))
    built = []
    # The subset is built (and reported) first, then the full dataset.
    for label_prefix, dataset in (("Prototype ", subset_dataset), ("", full_dataset)):
        loaders = make_dataset_loaders(
            dataset, seed, batch_size, test_size, aug_config, workers=num_workers
        )
        print("\n--- Handoff Test Successful ---")
        for split_key, split_label in split_labels:
            print(f"{label_prefix}{split_label} loader batches: {len(loaders[split_key])}")
        built.append(loaders)

    subset_loaders, full_loaders = built
    return subset_loaders, full_loaders, aug_config