# flowshop_transformer/source/process_dataset.py
import os
import numpy as np
import json
import argparse
import shutil
from loguru import logger
def process_dataset(input_dir, output_dir, eliminate_top_k_makespans, train_ratio, duplication_factor, seed):
    """
    Split the dataset in ``input_dir`` into train and validation sets.

    Loads ``schedules.npy``, ``makespans.npy`` and ``metadata.json`` from
    ``input_dir``, optionally removes the samples achieving the k best
    (lowest) makespans, shuffles, splits by ``train_ratio``, optionally
    duplicates training samples as augmentation, and writes everything to
    ``output_dir`` (``schedules_train.npy``, ``schedules_val.npy``, etc.),
    copying along ``pfsp_instance.npy`` and the heuristic makespan files.

    Args:
        input_dir: source dataset directory.
        output_dir: destination directory; created if it does not exist.
        eliminate_top_k_makespans: number of distinct best makespan values
            whose samples are removed before splitting.
        train_ratio: fraction of the (filtered) samples used for training.
        duplication_factor: extra training samples to add, as a fraction of
            the train split size (e.g. 0.5 adds 50% duplicated samples).
        seed: random seed for shuffling and duplication sampling.
    """
    # FIX: create the output directory FIRST. The log sink below and the
    # min_makespan save both write into output_dir; in the original code
    # os.makedirs only ran later, crashing when output_dir did not exist.
    os.makedirs(output_dir, exist_ok=True)
    logger.add(os.path.join(output_dir, "process_dataset.log"))
    logger.info(f"Starting dataset processing. Input: {input_dir} -> Output: {output_dir}")
    # paths
    metadata_path = os.path.join(input_dir, "metadata.json")
    schedules_path = os.path.join(input_dir, "schedules.npy")
    makespans_path = os.path.join(input_dir, "makespans.npy")
    if not os.path.exists(metadata_path):
        logger.error(f"Metadata file not found: {metadata_path}")
        return
    # load metadata
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    # load data
    logger.info("Loading data from .npy files...")
    schedules = np.load(schedules_path)
    makespans = np.load(makespans_path)
    logger.info(f"Loaded {schedules.shape[0]} schedules :=: {makespans.shape[0]} makespans.")
    # save the global minimum makespan (computed before any filtering, so it
    # reflects the best value in the full dataset)
    min_makespan = np.min(makespans[:, -1])
    np.save(os.path.join(output_dir, "min_makespan.npy"), min_makespan)
    # filtering logic: iteratively remove all samples sharing the current
    # best (minimum) final makespan, k times
    for _ in range(eliminate_top_k_makespans):
        min_makespan = np.min(makespans[:, -1])
        keep_indices = makespans[:, -1] != min_makespan
        removed_count = np.sum(~keep_indices)
        logger.info(f"Removing {removed_count} samples with best makespan ({min_makespan}).")
        schedules = schedules[keep_indices]
        makespans = makespans[keep_indices]
    # ======
    num_samples = schedules.shape[0]
    logger.info(f"Top k is {eliminate_top_k_makespans}. Filtered dataset contains {num_samples} samples.")
    indices = np.arange(num_samples)
    # shuffle
    logger.info(f"Shuffling with seed {seed}...")
    np.random.seed(seed)
    np.random.shuffle(indices)
    # splitting
    split_idx = int(num_samples * train_ratio)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]
    logger.info(f"Train ratio is {train_ratio}. Splitting into {len(train_indices)} train and {len(val_indices)} val samples.")
    # duplicate the train samples as a data augmentation if requested.
    # FIX: guard against nb_added_samples == 0 — np.random.choice raises
    # "a must be non-empty" on an empty train split even when size=0.
    nb_added_samples = int(len(train_indices) * duplication_factor)
    if nb_added_samples > 0:
        extra = np.random.choice(train_indices, size=nb_added_samples, replace=True)
        train_indices = np.concatenate([train_indices, extra])
    logger.info(f"Duplication factor is {duplication_factor}. Added {nb_added_samples} duplicated training samples. Total train samples is now: {len(train_indices)}")
    # save train data
    logger.info(f"Saving {len(train_indices)} training samples...")
    np.save(os.path.join(output_dir, "schedules_train.npy"), schedules[train_indices])
    np.save(os.path.join(output_dir, "makespans_train.npy"), makespans[train_indices])
    # save val data
    logger.info(f"Saving {len(val_indices)} validation samples...")
    np.save(os.path.join(output_dir, "schedules_val.npy"), schedules[val_indices])
    np.save(os.path.join(output_dir, "makespans_val.npy"), makespans[val_indices])
    # copy pfsp_instance.npy
    pfsp_path = os.path.join(input_dir, "pfsp_instance.npy")
    shutil.copy(pfsp_path, os.path.join(output_dir, "pfsp_instance.npy"))
    # create metadata file for the split dataset
    split_metadata = metadata.copy()
    split_metadata["nb_train_samples"] = len(train_indices)
    split_metadata["nb_val_samples"] = len(val_indices)
    split_metadata["data_source"] = input_dir
    with open(os.path.join(output_dir, "metadata.json"), "w") as f:
        json.dump(split_metadata, f, indent=4)
    logger.success(f"Successfully split dataset into {len(train_indices)} train and {len(val_indices)} val samples in {output_dir}.")
    # copy the heuristics results
    for filename in ["neh_makespan.npy", "cds_makespan.npy", "palmer_makespan.npy"]:
        src_path = os.path.join(input_dir, filename)
        shutil.copy(src_path, os.path.join(output_dir, filename))
# ======
if __name__ == "__main__":
    # NOTE: argparse is already imported at module level — the original
    # re-imported ArgumentParser here redundantly.
    parser = argparse.ArgumentParser(description="Split flowshop dataset into train and validation sets.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the source dataset directory")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory")
    parser.add_argument("--eliminate_top_k_makespans", type=int, required=True, help="Number of top makespans to eliminate")
    parser.add_argument("--train_ratio", type=float, required=True, help="Train/Val split ratio")
    parser.add_argument("--duplication_factor", type=float, required=True, help="Factor by which to duplicate the base samples (e.g., 0.5 means 50% more samples will be created by duplicating the base samples)")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for shuffling")
    args = parser.parse_args()
    process_dataset(
        args.input_dir,
        args.output_dir,
        args.eliminate_top_k_makespans,
        args.train_ratio,
        args.duplication_factor,
        args.seed,
    )
# ======