| import os |
| import numpy as np |
| import json |
| import argparse |
| import shutil |
| from loguru import logger |
|
|
|
|
def process_dataset(input_dir, output_dir, eliminate_top_k_makespans, train_ratio, duplication_factor, seed):
    """
    Split the flowshop dataset in ``input_dir`` into train and validation sets.

    Reads ``schedules.npy``, ``makespans.npy`` and ``metadata.json`` from
    ``input_dir``, removes the samples attaining the k best makespan values,
    shuffles, splits by ``train_ratio``, oversamples the train split by
    ``duplication_factor``, and writes the results to ``output_dir`` as
    ``schedules_train.npy`` / ``makespans_train.npy`` /
    ``schedules_val.npy`` / ``makespans_val.npy``, together with
    ``min_makespan.npy``, an updated ``metadata.json`` and copies of the
    auxiliary instance/baseline files found in ``input_dir``.

    Args:
        input_dir: Path to the source dataset directory.
        output_dir: Path to the output directory (created if missing).
        eliminate_top_k_makespans: Number of distinct best makespan values
            whose samples are removed before splitting.
        train_ratio: Fraction of the filtered samples used for training.
        duplication_factor: Extra training samples to add by resampling with
            replacement, as a fraction of the train split (0.5 adds 50% more).
        seed: Random seed for shuffling and duplication sampling.
    """
    # Create the output directory first: the log sink and every np.save below
    # need it to exist (previously np.save relied on loguru having created it
    # as a side effect of adding the file sink).
    os.makedirs(output_dir, exist_ok=True)
    logger.add(os.path.join(output_dir, "process_dataset.log"))
    logger.info(f"Starting dataset processing. Input: {input_dir} -> Output: {output_dir}")

    metadata_path = os.path.join(input_dir, "metadata.json")
    schedules_path = os.path.join(input_dir, "schedules.npy")
    makespans_path = os.path.join(input_dir, "makespans.npy")

    # Validate every required input up front (not just the metadata file) so a
    # missing array produces a clear error instead of a traceback in np.load.
    for required_path in (metadata_path, schedules_path, makespans_path):
        if not os.path.exists(required_path):
            logger.error(f"Required input file not found: {required_path}")
            return

    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    logger.info("Loading data from .npy files...")
    schedules = np.load(schedules_path)
    makespans = np.load(makespans_path)
    logger.info(f"Loaded {schedules.shape[0]} schedules :=: {makespans.shape[0]} makespans.")

    # Record the best makespan of the *unfiltered* dataset; the last column of
    # each makespan row holds the sample's final makespan.
    min_makespan = np.min(makespans[:, -1])
    np.save(os.path.join(output_dir, "min_makespan.npy"), min_makespan)

    # Iteratively drop every sample attaining the current best makespan,
    # repeating k times (i.e. remove the k best distinct makespan values).
    for _ in range(eliminate_top_k_makespans):
        if makespans.shape[0] == 0:
            # Guard: np.min on an empty array raises; stop once exhausted.
            logger.warning("Dataset exhausted while eliminating top makespans; stopping early.")
            break
        min_makespan = np.min(makespans[:, -1])
        keep_indices = makespans[:, -1] != min_makespan
        removed_count = np.sum(~keep_indices)
        logger.info(f"Removing {removed_count} samples with best makespan ({min_makespan}).")
        schedules = schedules[keep_indices]
        makespans = makespans[keep_indices]

    num_samples = schedules.shape[0]
    logger.info(f"Top k is {eliminate_top_k_makespans}. Filtered dataset contains {num_samples} samples.")
    indices = np.arange(num_samples)

    # Legacy global seeding is kept deliberately so existing splits reproduce.
    logger.info(f"Shuffling with seed {seed}...")
    np.random.seed(seed)
    np.random.shuffle(indices)

    split_idx = int(num_samples * train_ratio)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]
    logger.info(f"Train ratio is {train_ratio}. Splitting into {len(train_indices)} train and {len(val_indices)} val samples.")

    # Oversample the train split by duplication_factor (with replacement).
    nb_added_samples = int(len(train_indices) * duplication_factor)
    train_indices = np.concatenate([train_indices, np.random.choice(train_indices, size=nb_added_samples, replace=True)])
    logger.info(f"Duplication factor is {duplication_factor}. Added {nb_added_samples} duplicated training samples. Total train samples is now: {len(train_indices)}")

    logger.info(f"Saving {len(train_indices)} training samples...")
    np.save(os.path.join(output_dir, "schedules_train.npy"), schedules[train_indices])
    np.save(os.path.join(output_dir, "makespans_train.npy"), makespans[train_indices])

    logger.info(f"Saving {len(val_indices)} validation samples...")
    np.save(os.path.join(output_dir, "schedules_val.npy"), schedules[val_indices])
    np.save(os.path.join(output_dir, "makespans_val.npy"), makespans[val_indices])

    # Copy the problem instance alongside the splits; warn instead of crashing
    # if it is absent, so the already-written splits are not wasted.
    pfsp_path = os.path.join(input_dir, "pfsp_instance.npy")
    if os.path.exists(pfsp_path):
        shutil.copy(pfsp_path, os.path.join(output_dir, "pfsp_instance.npy"))
    else:
        logger.warning(f"Instance file not found, skipping copy: {pfsp_path}")

    # Write metadata augmented with split sizes and provenance.
    split_metadata = metadata.copy()
    split_metadata["nb_train_samples"] = len(train_indices)
    split_metadata["nb_val_samples"] = len(val_indices)
    split_metadata["data_source"] = input_dir
    with open(os.path.join(output_dir, "metadata.json"), "w") as f:
        json.dump(split_metadata, f, indent=4)

    logger.success(f"Successfully split dataset into {len(train_indices)} train and {len(val_indices)} val samples in {output_dir}.")

    # Heuristic baseline makespans are optional extras; skip missing ones with
    # a warning rather than failing after the split has already succeeded.
    for filename in ["neh_makespan.npy", "cds_makespan.npy", "palmer_makespan.npy"]:
        src_path = os.path.join(input_dir, filename)
        if os.path.exists(src_path):
            shutil.copy(src_path, os.path.join(output_dir, filename))
        else:
            logger.warning(f"Baseline file not found, skipping copy: {src_path}")
|
|
| |
|
|
|
|
if __name__ == "__main__":
    # Use the module-level `argparse` import; the local
    # `from argparse import ArgumentParser` re-import was redundant.
    parser = argparse.ArgumentParser(description="Split flowshop dataset into train and validation sets.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the source dataset directory")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory")
    parser.add_argument("--eliminate_top_k_makespans", type=int, required=True, help="Number of top makespans to eliminate")
    parser.add_argument("--train_ratio", type=float, required=True, help="Train/Val split ratio")
    # '%' must be escaped as '%%' in argparse help strings: argparse expands
    # help via %-formatting, so a bare '%' made `--help` raise ValueError.
    parser.add_argument("--duplication_factor", type=float, required=True, help="Factor by which to duplicate the base samples (e.g., 0.5 means 50%% more samples will be created by duplicating the base samples)")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for shuffling")
    args = parser.parse_args()

    process_dataset(
        args.input_dir,
        args.output_dir,
        args.eliminate_top_k_makespans,
        args.train_ratio,
        args.duplication_factor,
        args.seed,
    )
| |