File size: 5,480 Bytes
2d10cee
 
 
 
 
 
 
 
7569568
2d10cee
 
 
 
7569568
 
2d10cee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7569568
 
 
 
 
2d10cee
 
a897e4f
 
 
2d10cee
 
 
 
7569568
2d10cee
7569568
2d10cee
 
 
 
 
 
 
7569568
a897e4f
2d10cee
 
7569568
 
 
 
 
 
2d10cee
 
 
 
 
 
 
 
 
 
 
 
 
 
7569568
2d10cee
7569568
2d10cee
 
 
 
 
 
 
 
 
 
 
7569568
 
 
 
 
 
 
 
2d10cee
7569568
 
a897e4f
 
7569568
a897e4f
7569568
a897e4f
2d10cee
 
7569568
a897e4f
 
 
7569568
 
 
2d10cee
7569568
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import numpy as np
import json
import argparse
import shutil
from loguru import logger


def process_dataset(input_dir, output_dir, eliminate_top_k_makespans, train_ratio, duplication_factor, seed):
    """
    Split the flowshop dataset in `input_dir` into train and validation sets.

    Results are written to `output_dir` as schedules_train.npy, makespans_train.npy,
    schedules_val.npy, makespans_val.npy, min_makespan.npy, pfsp_instance.npy and an
    updated metadata.json. Heuristic makespan files are copied best-effort.

    Args:
        input_dir: Path to the source dataset directory (must contain metadata.json,
            schedules.npy, makespans.npy and pfsp_instance.npy).
        output_dir: Path to the output directory (created if missing).
        eliminate_top_k_makespans: Number of distinct best (minimal) makespan values to
            remove from the dataset before splitting.
        train_ratio: Fraction of samples assigned to the training split (0..1).
        duplication_factor: Fraction of extra training samples to create by duplicating
            existing training samples (e.g. 0.5 adds 50% more samples).
        seed: Random seed used for shuffling and duplication sampling.
    """
    # Output directory must exist before we attach the log sink and save any arrays.
    os.makedirs(output_dir, exist_ok=True)
    logger.add(os.path.join(output_dir, "process_dataset.log"))
    logger.info(f"Starting dataset processing. Input: {input_dir} -> Output: {output_dir}")

    # paths
    metadata_path = os.path.join(input_dir, "metadata.json")
    schedules_path = os.path.join(input_dir, "schedules.npy")
    makespans_path = os.path.join(input_dir, "makespans.npy")

    # fail early (with a logged error) if any required input is missing
    for required_path in (metadata_path, schedules_path, makespans_path):
        if not os.path.exists(required_path):
            logger.error(f"Required input file not found: {required_path}")
            return

    # load metadata
    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    # load data
    logger.info("Loading data from .npy files...")
    schedules = np.load(schedules_path)
    makespans = np.load(makespans_path)
    logger.info(f"Loaded {schedules.shape[0]} schedules :=: {makespans.shape[0]} makespans.")

    # save min makespan of the *unfiltered* dataset (used as a reference downstream)
    min_makespan = np.min(makespans[:,-1])
    np.save(os.path.join(output_dir, "min_makespan.npy"), min_makespan)

    # filtering logic: remove the k best (distinct) makespan values if requested;
    # each iteration drops every sample tied for the current minimum
    for _ in range(eliminate_top_k_makespans):
        min_makespan = np.min(makespans[:,-1])
        keep_indices = makespans[:,-1] != min_makespan
        removed_count = np.sum(~keep_indices)
        logger.info(f"Removing {removed_count} samples with best makespan ({min_makespan}).")
        schedules = schedules[keep_indices]
        makespans = makespans[keep_indices]
    # ======
    num_samples = schedules.shape[0]
    logger.info(f"Top k is {eliminate_top_k_makespans}. Filtered dataset contains {num_samples} samples.")
    indices = np.arange(num_samples)

    # shuffle
    logger.info(f"Shuffling with seed {seed}...")
    np.random.seed(seed)
    np.random.shuffle(indices)

    # splitting
    split_idx = int(num_samples * train_ratio)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]
    logger.info(f"Train ratio is {train_ratio}. Splitting into {len(train_indices)} train and {len(val_indices)} val samples.")

    # duplicate the train samples as a data augmentation if requested
    nb_added_samples = int(len(train_indices) * duplication_factor)
    train_indices = np.concatenate([train_indices, np.random.choice(train_indices, size=nb_added_samples, replace=True)])
    logger.info(f"Duplication factor is {duplication_factor}. Added {nb_added_samples} duplicated training samples. Total train samples is now: {len(train_indices)}")

    # save train data
    logger.info(f"Saving {len(train_indices)} training samples...")
    np.save(os.path.join(output_dir, "schedules_train.npy"), schedules[train_indices])
    np.save(os.path.join(output_dir, "makespans_train.npy"), makespans[train_indices])

    # save val data
    logger.info(f"Saving {len(val_indices)} validation samples...")
    np.save(os.path.join(output_dir, "schedules_val.npy"), schedules[val_indices])
    np.save(os.path.join(output_dir, "makespans_val.npy"), makespans[val_indices])

    # copy pfsp_instance.npy
    pfsp_path = os.path.join(input_dir, "pfsp_instance.npy")
    shutil.copy(pfsp_path, os.path.join(output_dir, "pfsp_instance.npy"))

    # create metadata file for the split dataset
    split_metadata = metadata.copy()
    split_metadata["nb_train_samples"] = len(train_indices)
    split_metadata["nb_val_samples"] = len(val_indices)
    split_metadata["data_source"] = input_dir
    with open(os.path.join(output_dir, "metadata.json"), "w") as f:
        json.dump(split_metadata, f, indent=4)

    logger.success(f"Successfully split dataset into {len(train_indices)} train and {len(val_indices)} val samples in {output_dir}.")

    # copy the heuristics results (best-effort: a missing heuristic file should not
    # invalidate the already-completed split)
    for filename in ["neh_makespan.npy", "cds_makespan.npy", "palmer_makespan.npy"]:
        src_path = os.path.join(input_dir, filename)
        if os.path.exists(src_path):
            shutil.copy(src_path, os.path.join(output_dir, filename))
        else:
            logger.warning(f"Heuristic result file not found, skipping copy: {src_path}")

# ======


if __name__ == "__main__":
    # CLI entry point: parse arguments and delegate to process_dataset.
    parser = argparse.ArgumentParser(description="Split flowshop dataset into train and validation sets.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the source dataset directory")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory")
    parser.add_argument("--eliminate_top_k_makespans", type=int, required=True, help="Number of top makespans to eliminate")
    parser.add_argument("--train_ratio", type=float, required=True, help="Train/Val split ratio")
    # NOTE: argparse applies %-formatting to help strings, so a literal percent
    # sign must be escaped as '%%' (a bare '%' makes --help raise ValueError).
    parser.add_argument("--duplication_factor", type=float, required=True, help="Factor by which to duplicate the base samples (e.g., 0.5 means 50%% more samples will be created by duplicating the base samples)")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for shuffling")
    args = parser.parse_args()

    process_dataset(
        args.input_dir, 
        args.output_dir,
        args.eliminate_top_k_makespans,
        args.train_ratio,
        args.duplication_factor,
        args.seed,
    )
# ======