import os

import numpy as np
from clearml import Task, Dataset
from datasets import load_dataset

from dataPrep.helpers.transforms_loaders import make_dataset_loaders

'''
Takes the latest Data Prep ClearML task from a project and reconstructs:
- data loaders for both the full and subset datasets
- the augmentation settings used
'''


def _parse_aug_config(data_params):
    """Parse the augmentation settings recorded by the Data Prep task.

    Single source of truth for the augmentation keys (previously this dict
    literal was duplicated in two places).
    """
    return {
        'rotation': float(data_params['General/augmentation/rotation']),
        'brightness': float(data_params['General/augmentation/brightness']),
        'saturation': float(data_params['General/augmentation/saturation']),
        'blur': float(data_params['General/augmentation/blur']),
    }


def extract_latest_data_task(project_name: str = "Small Group Project", num_workers: int = 0):
    """Find the most recent completed Data Preparation task in the given
    ClearML project and rebuild its data pipeline.

    Parameters
    ----------
    project_name : str
        ClearML project that contains the 'Data Preparation' subproject.
    num_workers : int
        Number of DataLoader worker processes.

    Returns
    -------
    tuple
        ``(subset_loaders, full_loaders, data_prep_metadata)`` where the
        loader dicts each hold 'train'/'val'/'test' DataLoaders and the
        metadata dict records the task id and the parameters it used.

    Raises
    ------
    RuntimeError
        If no suitable task exists, the 'subset_indices' artifact is
        missing, or the dataset cannot be loaded.
    """
    # --------- Get the latest Data Preparation task from ClearML ---------
    all_tasks = Task.get_tasks(
        project_name=f'{project_name}/Data Preparation',
        allow_archived=False,
        task_filter={'order_by': ["-last_update"]},
    )
    if not all_tasks:
        raise RuntimeError(f"No tasks found in project '{project_name}'")

    dp_tasks = [
        t for t in all_tasks
        if t.task_type == Task.TaskTypes.data_processing and t.completed is not None
    ]
    if not dp_tasks:
        raise RuntimeError("No 'Data Preparation' tasks found in this project!")

    # Tasks come back ordered newest-first ('-last_update'), so the first
    # surviving entry is the latest completed data-processing run.
    latest_task = dp_tasks[0]
    data_prep_task_id = latest_task.id
    data_prep = Task.get_task(task_id=data_prep_task_id)

    # Recover the subset indices uploaded by the Data Prep task.
    artifacts = data_prep.artifacts
    if "subset_indices" not in artifacts:
        raise RuntimeError("Data Prep task did not upload 'subset_indices' artifact!")
    subset_indices = np.load(artifacts["subset_indices"].get_local_copy())

    # Parameters recorded by the Data Prep task; parsed eagerly so a missing
    # key fails before the (slow) dataset download.
    data_params = data_prep.get_parameters()
    subset_ratio = float(data_params['General/dataset/subset_ratio'])
    dataset_link = data_params['General/dataset/link']
    seed = int(data_params['General/seed'])
    batch_size = int(data_params['General/dataloaders/batch_size'])
    test_size = float(data_params['General/dataloaders/test_size'])

    # Load the full dataset the Data Prep task was built from.
    try:
        ds = load_dataset(dataset_link)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Error loading the dataset: {e}") from e

    full_dataset = ds['train']

    # Selecting the stored indices reproduces the exact subset used in data prep.
    subset_dataset = full_dataset.select(subset_indices)

    # Build loaders for both datasets with the recorded parameters.
    subset_loaders, full_loaders, aug_config = get_data_loaders(
        data_params, subset_dataset, full_dataset, num_workers=num_workers
    )

    # Gather data prep task metadata for the caller's records.
    data_prep_metadata = {
        "data_prep_task_id": data_prep_task_id,
        "dataset_link": dataset_link,
        "subset_ratio_used": subset_ratio,
        "augmentation_used": aug_config,
        "batch_size_used": batch_size,
        "seed_used": seed,
        "test_size_used": test_size,
    }
    return subset_loaders, full_loaders, data_prep_metadata


'''
Takes a given dataset, subset, and data params to create DataLoaders.
Loaders split data into train, val, test.
'''


def get_data_loaders(data_params, subset_dataset, full_dataset, num_workers):
    """Create train/val/test DataLoaders for both the subset and full datasets.

    Parameters
    ----------
    data_params : dict
        Flat parameter dict from the Data Prep task (``Task.get_parameters``).
    subset_dataset, full_dataset
        Hugging Face datasets to wrap in loaders.
    num_workers : int
        Number of DataLoader worker processes.

    Returns
    -------
    tuple
        ``(subset_loaders, full_loaders, aug_config)``.
    """
    # Extract the loader parameters recorded by data prep.
    seed = int(data_params['General/seed'])
    batch_size = int(data_params['General/dataloaders/batch_size'])
    test_size = float(data_params['General/dataloaders/test_size'])
    aug_config = _parse_aug_config(data_params)

    # Loaders for the prototype subset.
    subset_loaders = make_dataset_loaders(
        subset_dataset, seed, batch_size, test_size, aug_config, workers=num_workers
    )
    print("\n--- Handoff Test Successful ---")
    print(f"Prototype Train loader batches: {len(subset_loaders['train'])}")
    print(f"Prototype Validation loader batches: {len(subset_loaders['val'])}")
    print(f"Prototype Test loader batches: {len(subset_loaders['test'])}")

    # Loaders for the full dataset.
    full_loaders = make_dataset_loaders(
        full_dataset, seed, batch_size, test_size, aug_config, workers=num_workers
    )
    print("\n--- Handoff Test Successful ---")
    print(f"Train loader batches: {len(full_loaders['train'])}")
    print(f"Validation loader batches: {len(full_loaders['val'])}")
    print(f"Test loader batches: {len(full_loaders['test'])}")

    return subset_loaders, full_loaders, aug_config