import argparse
import json
import os
import shutil

import numpy as np
from loguru import logger


def process_dataset(input_dir, output_dir, eliminate_top_k_makespans, train_ratio, duplication_factor, seed):
    """
    Splits the dataset in input_dir into train and validation sets.
    Saves the splits in output_dir with names like schedules_train.npy, etc.
    """
    # create the output directory before anything (log file, .npy files) is written to it
    os.makedirs(output_dir, exist_ok=True)

    logger.add(os.path.join(output_dir, "process_dataset.log"))
    logger.info(f"Starting dataset processing. Input: {input_dir} -> Output: {output_dir}")

    # paths
    metadata_path = os.path.join(input_dir, "metadata.json")
    schedules_path = os.path.join(input_dir, "schedules.npy")
    makespans_path = os.path.join(input_dir, "makespans.npy")

    if not os.path.exists(metadata_path):
        logger.error(f"Metadata file not found: {metadata_path}")
        return

    # load metadata
    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    # load data
    logger.info("Loading data from .npy files...")
    schedules = np.load(schedules_path)
    makespans = np.load(makespans_path)
    logger.info(f"Loaded {schedules.shape[0]} schedules and {makespans.shape[0]} makespans.")

    # save the best (minimum) final makespan of the unfiltered dataset
    min_makespan = np.min(makespans[:, -1])
    np.save(os.path.join(output_dir, "min_makespan.npy"), min_makespan)

    # filtering logic: remove the k best makespans if requested
    for _ in range(eliminate_top_k_makespans):
        min_makespan = np.min(makespans[:, -1])
        keep_indices = makespans[:, -1] != min_makespan
        removed_count = np.sum(~keep_indices)
        logger.info(f"Removing {removed_count} samples with best makespan ({min_makespan}).")
        schedules = schedules[keep_indices]
        makespans = makespans[keep_indices]

    num_samples = schedules.shape[0]
    logger.info(f"Top k is {eliminate_top_k_makespans}. Filtered dataset contains {num_samples} samples.")
    indices = np.arange(num_samples)

    # shuffle
    logger.info(f"Shuffling with seed {seed}...")
    np.random.seed(seed)
    np.random.shuffle(indices)

    # splitting
    split_idx = int(num_samples * train_ratio)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]
    logger.info(
        f"Train ratio is {train_ratio}. Splitting into {len(train_indices)} train and {len(val_indices)} val samples."
    )

    # duplicate train samples (sampling with replacement) as a data augmentation if requested
    nb_added_samples = int(len(train_indices) * duplication_factor)
    train_indices = np.concatenate([train_indices, np.random.choice(train_indices, size=nb_added_samples, replace=True)])
    logger.info(
        f"Duplication factor is {duplication_factor}. Added {nb_added_samples} duplicated training samples. "
        f"Total train samples is now: {len(train_indices)}"
    )

    # save train data
    logger.info(f"Saving {len(train_indices)} training samples...")
    np.save(os.path.join(output_dir, "schedules_train.npy"), schedules[train_indices])
    np.save(os.path.join(output_dir, "makespans_train.npy"), makespans[train_indices])

    # save val data
    logger.info(f"Saving {len(val_indices)} validation samples...")
    np.save(os.path.join(output_dir, "schedules_val.npy"), schedules[val_indices])
    np.save(os.path.join(output_dir, "makespans_val.npy"), makespans[val_indices])

    # copy pfsp_instance.npy
    pfsp_path = os.path.join(input_dir, "pfsp_instance.npy")
    shutil.copy(pfsp_path, os.path.join(output_dir, "pfsp_instance.npy"))

    # create metadata file for the split dataset
    split_metadata = metadata.copy()
    split_metadata["nb_train_samples"] = len(train_indices)
    split_metadata["nb_val_samples"] = len(val_indices)
    split_metadata["data_source"] = input_dir
    with open(os.path.join(output_dir, "metadata.json"), "w") as f:
        json.dump(split_metadata, f, indent=4)

    logger.success(
        f"Successfully split dataset into {len(train_indices)} train and {len(val_indices)} val samples in {output_dir}."
    )

    # copy the heuristics results
    for filename in ["neh_makespan.npy", "cds_makespan.npy", "palmer_makespan.npy"]:
        src_path = os.path.join(input_dir, filename)
        shutil.copy(src_path, os.path.join(output_dir, filename))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Split flowshop dataset into train and validation sets.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the source dataset directory")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory")
    parser.add_argument("--eliminate_top_k_makespans", type=int, required=True, help="Number of top makespans to eliminate")
    parser.add_argument("--train_ratio", type=float, required=True, help="Train/Val split ratio")
    parser.add_argument(
        "--duplication_factor",
        type=float,
        required=True,
        help="Factor by which to duplicate the base samples (e.g., 0.5 means 50%% more samples will be created by duplicating the base samples)",
    )
    parser.add_argument("--seed", type=int, required=True, help="Random seed for shuffling")
    args = parser.parse_args()

    process_dataset(
        args.input_dir,
        args.output_dir,
        args.eliminate_top_k_makespans,
        args.train_ratio,
        args.duplication_factor,
        args.seed,
    )
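
# Example invocation (a minimal sketch: the script filename and directory paths below are
# hypothetical and should be adapted to the actual dataset layout):
#
#   python process_dataset.py \
#       --input_dir datasets/flowshop_raw \
#       --output_dir datasets/flowshop_split \
#       --eliminate_top_k_makespans 1 \
#       --train_ratio 0.8 \
#       --duplication_factor 0.5 \
#       --seed 42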