commit
Browse files- source/.ipynb +22 -0
- source/create_dataset.bash +4 -2
- source/create_dataset.py +29 -7
- source/demos/ftd/create_dataset.log +16 -0
- source/demos/ftd/metadata.json +1 -1
- source/demos/rs_artifacts/recover_schedules.log +19 -19
- source/demos/rs_artifacts/recover_schedules_pbar.log +1 -1
- source/demos/train_artifacts/checkpoints/best_checkpoint.pth +1 -1
- source/demos/train_artifacts/checkpoints/last_checkpoint.pth +1 -1
- source/demos/train_artifacts/train.log +41 -41
- source/demos/train_artifacts/train_pbar_epoch.log +1 -1
- source/demos/train_artifacts/train_pbar_val.log +1 -1
- source/launch_create_dataset.py +23 -0
- source/launch_process_dataset.py +18 -0
- source/launch_train.py +42 -0
- source/train.py +162 -97
source/.ipynb
CHANGED
|
@@ -3,6 +3,28 @@
|
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
"execution_count": 2,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"id": "ee46ab38",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [
|
|
|
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
"execution_count": 2,
|
| 6 |
+
"id": "6ff2d58c",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"data": {
|
| 11 |
+
"text/plain": [
|
| 12 |
+
"False"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
"execution_count": 2,
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"output_type": "execute_result"
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"source": [
|
| 21 |
+
"import os\n",
|
| 22 |
+
"os.path.exists(\"\")"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": null,
|
| 28 |
"id": "ee46ab38",
|
| 29 |
"metadata": {},
|
| 30 |
"outputs": [
|
source/create_dataset.bash
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
python create_dataset.py\
|
| 2 |
-
--testing
|
| 3 |
--nb_jobs 7\
|
| 4 |
--nb_machines 2\
|
| 5 |
--time_min 0\
|
|
@@ -8,4 +8,6 @@ python create_dataset.py\
|
|
| 8 |
--init_type exhaustive\
|
| 9 |
--output_dir "./demos/ftd"\
|
| 10 |
--seed 97\
|
| 11 |
-
--normalize_makespans "true"\
|
|
|
|
|
|
|
|
|
| 1 |
python create_dataset.py\
|
| 2 |
+
--testing False\
|
| 3 |
--nb_jobs 7\
|
| 4 |
--nb_machines 2\
|
| 5 |
--time_min 0\
|
|
|
|
| 8 |
--init_type exhaustive\
|
| 9 |
--output_dir "./demos/ftd"\
|
| 10 |
--seed 97\
|
| 11 |
+
--normalize_makespans "true"\
|
| 12 |
+
--pfsp_instance ""\
|
| 13 |
+
--autoname_output_dir ""\
|
source/create_dataset.py
CHANGED
|
@@ -6,6 +6,7 @@ import math
|
|
| 6 |
from loguru import logger
|
| 7 |
import tqdm
|
| 8 |
import time
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def generate_random_pfsp_instance(nb_jobs, nb_machines, time_min, time_max, seed=97):
|
|
@@ -389,19 +390,36 @@ def create_dataset(
|
|
| 389 |
output_dir,
|
| 390 |
seed,
|
| 391 |
normalize_makespans,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
):
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
# create the output folder
|
| 395 |
os.makedirs(output_dir, exist_ok=True)
|
| 396 |
|
| 397 |
# check if experiment termination flag file exists
|
| 398 |
if not testing:
|
| 399 |
-
if os.path.exists(os.path.join(
|
| 400 |
print("Dataset creation already done. Exiting...")
|
| 401 |
return None
|
| 402 |
|
| 403 |
-
# prepare logging
|
| 404 |
-
logger.add(os.path.join(output_dir, "create_dataset.log"))
|
| 405 |
|
| 406 |
# log parameters
|
| 407 |
logger.info(f"nb_samples: {nb_samples}")
|
|
@@ -561,21 +579,25 @@ if __name__ == "__main__":
|
|
| 561 |
parser.add_argument("--nb_samples", type=int, required=True, help="Number of base samples to generate")
|
| 562 |
parser.add_argument("--init_type", type=str, required=True, choices=["exhaustive", "cds", "palmer", "neh", "heuristics", "random"], help="Initialization type for the base samples")
|
| 563 |
parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory where dataset artifacts will be saved")
|
|
|
|
| 564 |
parser.add_argument("--seed", type=int, required=True, help="Random seed for reproducibility (set to None for no seeding)")
|
| 565 |
parser.add_argument("--normalize_makespans", type=bool, required=True, help="Whether to normalize makespans by the sum of processing times")
|
|
|
|
| 566 |
args = parser.parse_args()
|
| 567 |
-
|
| 568 |
-
# create pfsp_instance
|
| 569 |
-
pfsp_instance = generate_random_pfsp_instance(args.nb_jobs, args.nb_machines, args.time_min, args.time_max, seed=args.seed)
|
| 570 |
|
| 571 |
# create the dataset
|
| 572 |
create_dataset(
|
| 573 |
testing=args.testing,
|
| 574 |
-
pfsp_instance=pfsp_instance,
|
| 575 |
nb_samples=args.nb_samples,
|
| 576 |
init_type=args.init_type,
|
| 577 |
output_dir=args.output_dir,
|
| 578 |
seed=args.seed,
|
| 579 |
normalize_makespans=args.normalize_makespans,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
)
|
| 581 |
# ======
|
|
|
|
| 6 |
from loguru import logger
|
| 7 |
import tqdm
|
| 8 |
import time
|
| 9 |
+
import namer
|
| 10 |
|
| 11 |
|
| 12 |
def generate_random_pfsp_instance(nb_jobs, nb_machines, time_min, time_max, seed=97):
|
|
|
|
| 390 |
output_dir,
|
| 391 |
seed,
|
| 392 |
normalize_makespans,
|
| 393 |
+
nb_jobs,
|
| 394 |
+
nb_machines,
|
| 395 |
+
time_min,
|
| 396 |
+
time_max,
|
| 397 |
+
autoname_output_dir,
|
| 398 |
):
|
| 399 |
|
| 400 |
+
if autoname_output_dir:
|
| 401 |
+
output_dir = os.path.join(output_dir, time.strftime("%Y_%m_%d_%H_%M_%S") + "_" + namer.generate(separator="_", category="sports"))
|
| 402 |
+
|
| 403 |
+
# prepare logging
|
| 404 |
+
logger.add(os.path.join(output_dir, "create_dataset.log"))
|
| 405 |
+
|
| 406 |
+
if os.path.exists(pfsp_instance):
|
| 407 |
+
# TODO: add logic to load pfsp_instance from some file
|
| 408 |
+
pass
|
| 409 |
+
else:
|
| 410 |
+
# create pfsp_instance
|
| 411 |
+
logger.info(f"Creating pfsp_instance with {nb_jobs} jobs and {nb_machines} machines")
|
| 412 |
+
pfsp_instance = generate_random_pfsp_instance(nb_jobs, nb_machines, time_min, time_max, seed=seed)
|
| 413 |
+
|
| 414 |
# create the output folder
|
| 415 |
os.makedirs(output_dir, exist_ok=True)
|
| 416 |
|
| 417 |
# check if experiment termination flag file exists
|
| 418 |
if not testing:
|
| 419 |
+
if os.path.exists(os.path.join(output_dir, ".terminated_create_dataset")):
|
| 420 |
print("Dataset creation already done. Exiting...")
|
| 421 |
return None
|
| 422 |
|
|
|
|
|
|
|
| 423 |
|
| 424 |
# log parameters
|
| 425 |
logger.info(f"nb_samples: {nb_samples}")
|
|
|
|
| 579 |
parser.add_argument("--nb_samples", type=int, required=True, help="Number of base samples to generate")
|
| 580 |
parser.add_argument("--init_type", type=str, required=True, choices=["exhaustive", "cds", "palmer", "neh", "heuristics", "random"], help="Initialization type for the base samples")
|
| 581 |
parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory where dataset artifacts will be saved")
|
| 582 |
+
parser.add_argument("--autoname_output_dir", type=bool, required=True, help="Whether to autoname the output directory")
|
| 583 |
parser.add_argument("--seed", type=int, required=True, help="Random seed for reproducibility (set to None for no seeding)")
|
| 584 |
parser.add_argument("--normalize_makespans", type=bool, required=True, help="Whether to normalize makespans by the sum of processing times")
|
| 585 |
+
parser.add_argument("--pfsp_instance", type=str, required=True, help="Path to the pfsp instance or None if to be generated")
|
| 586 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
# create the dataset
|
| 589 |
create_dataset(
|
| 590 |
testing=args.testing,
|
|
|
|
| 591 |
nb_samples=args.nb_samples,
|
| 592 |
init_type=args.init_type,
|
| 593 |
output_dir=args.output_dir,
|
| 594 |
seed=args.seed,
|
| 595 |
normalize_makespans=args.normalize_makespans,
|
| 596 |
+
pfsp_instance=args.pfsp_instance,
|
| 597 |
+
nb_jobs=args.nb_jobs,
|
| 598 |
+
nb_machines=args.nb_machines,
|
| 599 |
+
time_min=args.time_min,
|
| 600 |
+
time_max=args.time_max,
|
| 601 |
+
autoname_output_dir=args.autoname_output_dir,
|
| 602 |
)
|
| 603 |
# ======
|
source/demos/ftd/create_dataset.log
CHANGED
|
@@ -5,3 +5,19 @@
|
|
| 5 |
2026-03-10 13:57:19.791 | INFO | __main__:create_dataset:415 - Normalizing makespans by the sum of processing times with pfsp sum: 7.829251766204834
|
| 6 |
2026-03-10 13:57:19.791 | INFO | __main__:create_dataset:422 - Exhaustive init_type: Number of samples: 5040
|
| 7 |
2026-03-10 13:57:19.834 | INFO | __main__:create_dataset:533 - Minimum makespan: 0.5417570471763611
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
2026-03-10 13:57:19.791 | INFO | __main__:create_dataset:415 - Normalizing makespans by the sum of processing times with pfsp sum: 7.829251766204834
|
| 6 |
2026-03-10 13:57:19.791 | INFO | __main__:create_dataset:422 - Exhaustive init_type: Number of samples: 5040
|
| 7 |
2026-03-10 13:57:19.834 | INFO | __main__:create_dataset:533 - Minimum makespan: 0.5417570471763611
|
| 8 |
+
2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:418 - nb_samples: 0
|
| 9 |
+
2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:419 - init_type: exhaustive
|
| 10 |
+
2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:420 - output_dir: ./demos/ftd
|
| 11 |
+
2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:421 - seed: 97
|
| 12 |
+
2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:420 - nb_samples: 0
|
| 13 |
+
2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:421 - init_type: exhaustive
|
| 14 |
+
2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:422 - output_dir: ./demos/ftd
|
| 15 |
+
2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:423 - seed: 97
|
| 16 |
+
2026-03-12 11:20:45.676 | INFO | __main__:create_dataset:405 - Creating pfsp_instance with 7 jobs and 2 machines
|
| 17 |
+
2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:419 - nb_samples: 0
|
| 18 |
+
2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:420 - init_type: exhaustive
|
| 19 |
+
2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:421 - output_dir: ./demos/ftd
|
| 20 |
+
2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:422 - seed: 97
|
| 21 |
+
2026-03-12 11:20:45.704 | INFO | __main__:create_dataset:427 - Normalizing makespans by the sum of processing times with pfsp sum: 7.829251766204834
|
| 22 |
+
2026-03-12 11:20:45.704 | INFO | __main__:create_dataset:434 - Exhaustive init_type: Number of samples: 5040
|
| 23 |
+
2026-03-12 11:20:45.751 | INFO | __main__:create_dataset:545 - Minimum makespan: 0.5417570471763611
|
source/demos/ftd/metadata.json
CHANGED
|
@@ -5,5 +5,5 @@
|
|
| 5 |
"init_type": "exhaustive",
|
| 6 |
"data_path": "./demos/ftd",
|
| 7 |
"seed": 97,
|
| 8 |
-
"date_time": "
|
| 9 |
}
|
|
|
|
| 5 |
"init_type": "exhaustive",
|
| 6 |
"data_path": "./demos/ftd",
|
| 7 |
"seed": 97,
|
| 8 |
+
"date_time": "2026_03_12_11_20_45"
|
| 9 |
}
|
source/demos/rs_artifacts/recover_schedules.log
CHANGED
|
@@ -1,19 +1,19 @@
|
|
| 1 |
-
2026-03-
|
| 2 |
-
2026-03-
|
| 3 |
-
2026-03-
|
| 4 |
-
2026-03-
|
| 5 |
-
2026-03-
|
| 6 |
-
2026-03-
|
| 7 |
-
2026-03-
|
| 8 |
-
2026-03-
|
| 9 |
-
2026-03-
|
| 10 |
-
2026-03-
|
| 11 |
-
2026-03-
|
| 12 |
-
2026-03-
|
| 13 |
-
2026-03-
|
| 14 |
-
2026-03-
|
| 15 |
-
2026-03-
|
| 16 |
-
2026-03-
|
| 17 |
-
2026-03-
|
| 18 |
-
2026-03-
|
| 19 |
-
2026-03-
|
|
|
|
| 1 |
+
2026-03-12 13:06:43.046 | INFO | __main__:<module>:395 - Found better makespan!:
|
| 2 |
+
2026-03-12 13:06:43.046 | INFO | __main__:<module>:396 - recovered permutation: [1 4 3 2 6 5 0]
|
| 3 |
+
2026-03-12 13:06:43.046 | INFO | __main__:<module>:397 - actual makespan: 5.0353
|
| 4 |
+
2026-03-12 13:06:43.047 | INFO | __main__:<module>:398 - actual makespan normalized: 0.6431
|
| 5 |
+
2026-03-12 13:06:43.047 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.6463
|
| 6 |
+
2026-03-12 13:06:43.257 | INFO | __main__:<module>:395 - Found better makespan!:
|
| 7 |
+
2026-03-12 13:06:43.257 | INFO | __main__:<module>:396 - recovered permutation: [6 1 3 4 0 5 2]
|
| 8 |
+
2026-03-12 13:06:43.258 | INFO | __main__:<module>:397 - actual makespan: 4.6095
|
| 9 |
+
2026-03-12 13:06:43.258 | INFO | __main__:<module>:398 - actual makespan normalized: 0.5888
|
| 10 |
+
2026-03-12 13:06:43.258 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.5881
|
| 11 |
+
2026-03-12 13:06:55.273 | INFO | __main__:<module>:395 - Found better makespan!:
|
| 12 |
+
2026-03-12 13:06:55.273 | INFO | __main__:<module>:396 - recovered permutation: [4 6 3 1 0 5 2]
|
| 13 |
+
2026-03-12 13:06:55.273 | INFO | __main__:<module>:397 - actual makespan: 4.3768
|
| 14 |
+
2026-03-12 13:06:55.273 | INFO | __main__:<module>:398 - actual makespan normalized: 0.5590
|
| 15 |
+
2026-03-12 13:06:55.273 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.5755
|
| 16 |
+
2026-03-12 13:07:23.978 | INFO | __main__:<module>:411 - NEH makespan: 0.5418
|
| 17 |
+
2026-03-12 13:07:23.978 | INFO | __main__:<module>:412 - CDS makespan: 0.5418
|
| 18 |
+
2026-03-12 13:07:23.978 | INFO | __main__:<module>:413 - Palmer makespan: 0.5595
|
| 19 |
+
2026-03-12 13:07:23.978 | INFO | __main__:<module>:414 - Best makespan found by optimization: 0.5590
|
source/demos/rs_artifacts/recover_schedules_pbar.log
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Latent schedules optimization: 100%|████████████████████████
|
|
|
|
| 1 |
+
Latent schedules optimization: 100%|████████████████████████| 2000/2000 [00:41<00:00, 48.44it/s, total loss=1.18e+3, makespan=-0.0432, sinkhorn=592]
|
source/demos/train_artifacts/checkpoints/best_checkpoint.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1272039
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7f2a830b2b04168daacedd0b0c41ed034c98ec3be35611dd2a32c8cf6602353
|
| 3 |
size 1272039
|
source/demos/train_artifacts/checkpoints/last_checkpoint.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1272039
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6500dac2061304bfd78de1d483743f4e57cc9830e68f30e148c9c9df5d483f62
|
| 3 |
size 1272039
|
source/demos/train_artifacts/train.log
CHANGED
|
@@ -1,41 +1,41 @@
|
|
| 1 |
-
2026-03-
|
| 2 |
-
2026-03-
|
| 3 |
-
2026-03-
|
| 4 |
-
2026-03-
|
| 5 |
-
2026-03-
|
| 6 |
-
2026-03-
|
| 7 |
-
2026-03-
|
| 8 |
-
2026-03-
|
| 9 |
-
2026-03-
|
| 10 |
-
2026-03-
|
| 11 |
-
2026-03-
|
| 12 |
-
2026-03-
|
| 13 |
-
2026-03-
|
| 14 |
-
2026-03-
|
| 15 |
-
2026-03-
|
| 16 |
-
2026-03-
|
| 17 |
-
2026-03-
|
| 18 |
-
2026-03-
|
| 19 |
-
2026-03-
|
| 20 |
-
2026-03-
|
| 21 |
-
2026-03-
|
| 22 |
-
2026-03-
|
| 23 |
-
2026-03-
|
| 24 |
-
2026-03-
|
| 25 |
-
2026-03-
|
| 26 |
-
2026-03-
|
| 27 |
-
2026-03-
|
| 28 |
-
2026-03-
|
| 29 |
-
2026-03-
|
| 30 |
-
2026-03-
|
| 31 |
-
2026-03-
|
| 32 |
-
2026-03-
|
| 33 |
-
2026-03-
|
| 34 |
-
2026-03-
|
| 35 |
-
2026-03-
|
| 36 |
-
2026-03-
|
| 37 |
-
2026-03-
|
| 38 |
-
2026-03-
|
| 39 |
-
2026-03-
|
| 40 |
-
2026-03-
|
| 41 |
-
2026-03-
|
|
|
|
| 1 |
+
2026-03-12 13:06:19.589 | INFO | __main__:__init__:288 - Loaded schedules from ./demos/ftd_processed/schedules_train.npy with shape (8568, 7)
|
| 2 |
+
2026-03-12 13:06:19.590 | INFO | __main__:__init__:293 - Loaded makespans from ./demos/ftd_processed/makespans_train.npy with shape (8568, 7)
|
| 3 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:321 - schedules.shape: torch.Size([64, 7])
|
| 4 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:322 - makespans.shape: torch.Size([64, 7])
|
| 5 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:388 - data_dir: ./demos/ftd_processed
|
| 6 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:389 - block_size: 7
|
| 7 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:390 - vocab_size: 7
|
| 8 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:391 - n_embd: 64
|
| 9 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:392 - n_head: 4
|
| 10 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:393 - n_layer: 2
|
| 11 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:394 - ff_width: 4
|
| 12 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:395 - train_batch_size: 64
|
| 13 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:396 - val_batch_size: 256
|
| 14 |
+
2026-03-12 13:06:19.605 | INFO | __main__:train:397 - dropout: 0.0
|
| 15 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:398 - nb_epochs: 5
|
| 16 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:399 - early_stopping_patience: 15
|
| 17 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:400 - nb_iters: 670
|
| 18 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:401 - checkpoint_interval: 33
|
| 19 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:402 - decay_lr: True
|
| 20 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:403 - lr_partitions_ratios: [0.66, None]
|
| 21 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:404 - lr_partitions_iters: [442, 228]
|
| 22 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:405 - init_lr: 0.0001
|
| 23 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:406 - max_lr: 0.001
|
| 24 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:407 - min_lr: 5e-05
|
| 25 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:408 - lr_warmup_iters_ratio: 0.1
|
| 26 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:409 - lr_decay_iters_ratio: 0.95
|
| 27 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:410 - beta1: 0.9
|
| 28 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:411 - beta2: 0.95
|
| 29 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:412 - weight_decay: 0.1
|
| 30 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:413 - grad_clip: 1.0
|
| 31 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:414 - compile: False
|
| 32 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:415 - compile_mode: default
|
| 33 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:416 - intermediate_schedules: True
|
| 34 |
+
2026-03-12 13:06:19.606 | INFO | __main__:train:417 - save_only_last_checkpoint: True
|
| 35 |
+
2026-03-12 13:06:19.703 | INFO | __main__:train:485 - The model has 100K trainable parameters
|
| 36 |
+
2026-03-12 13:06:19.703 | INFO | __main__:train:506 - num decayed parameter tensors: 40, with 99,712 parameters
|
| 37 |
+
2026-03-12 13:06:19.703 | INFO | __main__:train:507 - num non-decayed parameter tensors: 9, with 513 parameters
|
| 38 |
+
2026-03-12 13:06:19.704 | INFO | __main__:train:511 - using fused AdamW: True
|
| 39 |
+
2026-03-12 13:06:20.458 | INFO | __main__:__init__:288 - Loaded schedules from ./demos/ftd_processed/schedules_val.npy with shape (756, 7)
|
| 40 |
+
2026-03-12 13:06:20.458 | INFO | __main__:__init__:293 - Loaded makespans from ./demos/ftd_processed/makespans_val.npy with shape (756, 7)
|
| 41 |
+
2026-03-12 13:06:28.461 | INFO | __main__:train:698 - Best validation loss: 0.0003
|
source/demos/train_artifacts/train_pbar_epoch.log
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Epoch 5/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████
|
|
|
|
| 1 |
+
Epoch 5/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:01<00:00, 82.19it/s]
|
source/demos/train_artifacts/train_pbar_val.log
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Validation 5.00: 100%|███████████████████████████████████████████████████████████████████████████████████████████████
|
|
|
|
| 1 |
+
Validation 5.00: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 232.26it/s]
|
source/launch_create_dataset.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from create_dataset import create_dataset
|
| 2 |
+
|
| 3 |
+
params = [
|
| 4 |
+
{
|
| 5 |
+
"testing": False,
|
| 6 |
+
"nb_jobs": nb_jobs,
|
| 7 |
+
"nb_machines": nb_machines,
|
| 8 |
+
"time_min": 0,
|
| 9 |
+
"time_max": 1,
|
| 10 |
+
"nb_samples": 0,
|
| 11 |
+
"init_type": "exhaustive",
|
| 12 |
+
"output_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}",
|
| 13 |
+
"seed": 97,
|
| 14 |
+
"normalize_makespans": True,
|
| 15 |
+
"pfsp_instance": "",
|
| 16 |
+
"autoname_output_dir": False,
|
| 17 |
+
}
|
| 18 |
+
for nb_jobs in range(7, 11)
|
| 19 |
+
for nb_machines in range(2, 7)
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
for param in params:
|
| 23 |
+
create_dataset(**param)
|
source/launch_process_dataset.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from process_dataset import process_dataset
|
| 2 |
+
|
| 3 |
+
params = [
|
| 4 |
+
{
|
| 5 |
+
"input_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}",
|
| 6 |
+
"output_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{k_eliminated}",
|
| 7 |
+
"train_ratio": 0.85,
|
| 8 |
+
"seed": 97,
|
| 9 |
+
"eliminate_top_k_makespans": k_eliminated,
|
| 10 |
+
"duplication_factor": 0.0,
|
| 11 |
+
}
|
| 12 |
+
for nb_jobs in range(7, 11)
|
| 13 |
+
for nb_machines in range(2, 7)
|
| 14 |
+
for k_eliminated in [0, 1, 2, 3, 4]
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
for param in params:
|
| 18 |
+
process_dataset(**param)
|
source/launch_train.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from train import train
|
| 2 |
+
|
| 3 |
+
params = [
|
| 4 |
+
{
|
| 5 |
+
"testing": False,
|
| 6 |
+
"seed": 97,
|
| 7 |
+
"data_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{top_k}",
|
| 8 |
+
"n_embd": 64,
|
| 9 |
+
"n_head": 4,
|
| 10 |
+
"n_layer": 2,
|
| 11 |
+
"ff_width": 4,
|
| 12 |
+
"intermediate_schedules": True,
|
| 13 |
+
"train_batch_size": 128,
|
| 14 |
+
"val_batch_size": 256,
|
| 15 |
+
"nb_epochs": 5,
|
| 16 |
+
"early_stopping_patience": 15,
|
| 17 |
+
"dropout": 0.0,
|
| 18 |
+
"checkpoint_interval_ratio": 1.0,
|
| 19 |
+
"decay_lr": True,
|
| 20 |
+
"lr_partitions_ratios": [0.66],
|
| 21 |
+
"init_lr": 1e-4,
|
| 22 |
+
"max_lr": 1e-3,
|
| 23 |
+
"min_lr": 5e-5,
|
| 24 |
+
"lr_warmup_iters_ratio": 0.1,
|
| 25 |
+
"lr_decay_iters_ratio": 0.95,
|
| 26 |
+
"beta1": 0.9,
|
| 27 |
+
"beta2": 0.95,
|
| 28 |
+
"weight_decay": 1e-1,
|
| 29 |
+
"grad_clip": 1.0,
|
| 30 |
+
"compile": "",
|
| 31 |
+
"compile_mode": "default",
|
| 32 |
+
"save_only_last_checkpoint": False,
|
| 33 |
+
"output_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{top_k}/train_Sm_Wd1e-1",
|
| 34 |
+
}
|
| 35 |
+
for nb_jobs in [7, 8, 9]
|
| 36 |
+
for nb_machines in [2, 3, 4, 5, 6]
|
| 37 |
+
for top_k in [0, 1, 2, 3, 4]
|
| 38 |
+
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
for param in params:
|
| 42 |
+
train(**param)
|
source/train.py
CHANGED
|
@@ -173,52 +173,48 @@ class GPT(nn.Module):
|
|
| 173 |
# ======
|
| 174 |
|
| 175 |
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
parser.add_argument("--output_dir", type=str, required=True)
|
| 210 |
-
args = parser.parse_args()
|
| 211 |
-
|
| 212 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
| 213 |
|
| 214 |
# check if experiment termination flag file exists
|
| 215 |
-
if not
|
| 216 |
-
if os.path.exists(os.path.join(
|
| 217 |
print("Phase 1 already terminated. Exiting...")
|
| 218 |
-
|
| 219 |
# ======
|
| 220 |
-
if not os.path.exists(os.path.join(
|
| 221 |
-
shutil.copy("viz_train.ipynb",
|
| 222 |
# ======
|
| 223 |
else:
|
| 224 |
|
|
@@ -236,12 +232,12 @@ if __name__ == "__main__":
|
|
| 236 |
"viz_train.ipynb",
|
| 237 |
]
|
| 238 |
for f in files_to_delete:
|
| 239 |
-
f_path = os.path.join(
|
| 240 |
if os.path.exists(f_path): os.remove(f_path)
|
| 241 |
# ======
|
| 242 |
-
checkpoints_dir = os.path.join(
|
| 243 |
if os.path.exists(checkpoints_dir): shutil.rmtree(checkpoints_dir)
|
| 244 |
-
shutil.copy("viz_train.ipynb",
|
| 245 |
# ======
|
| 246 |
|
| 247 |
# check if GPU is available
|
|
@@ -249,31 +245,31 @@ if __name__ == "__main__":
|
|
| 249 |
device = "cuda"
|
| 250 |
|
| 251 |
# setup logging
|
| 252 |
-
loguru.logger.add(os.path.join(
|
| 253 |
|
| 254 |
# set random seeds
|
| 255 |
-
torch.manual_seed(
|
| 256 |
-
random.seed(
|
| 257 |
-
np.random.seed(
|
| 258 |
|
| 259 |
# setup model architecture parameters
|
| 260 |
-
with open(os.path.join(
|
| 261 |
metadata = json.load(f)
|
| 262 |
block_size = metadata["nb_jobs"] # context window size
|
| 263 |
vocab_size = metadata["nb_jobs"] # vocabulary size
|
| 264 |
-
n_embd =
|
| 265 |
-
n_head =
|
| 266 |
assert n_embd % n_head == 0
|
| 267 |
-
n_layer =
|
| 268 |
-
intermediate_schedules =
|
| 269 |
-
ff_width =
|
| 270 |
|
| 271 |
# setup training parameters and utils
|
| 272 |
-
train_batch_size =
|
| 273 |
-
val_batch_size =
|
| 274 |
-
nb_epochs =
|
| 275 |
-
early_stopping_patience =
|
| 276 |
-
dropout =
|
| 277 |
|
| 278 |
|
| 279 |
class FlowshopDataset(torch.utils.data.Dataset):
|
|
@@ -313,7 +309,7 @@ if __name__ == "__main__":
|
|
| 313 |
# ======
|
| 314 |
|
| 315 |
|
| 316 |
-
train_dataset = FlowshopDataset(
|
| 317 |
train_data_loader = torch.utils.data.DataLoader(
|
| 318 |
train_dataset,
|
| 319 |
batch_size=train_batch_size,
|
|
@@ -326,24 +322,24 @@ if __name__ == "__main__":
|
|
| 326 |
loguru.logger.info(f"makespans.shape: {makespans.shape}")
|
| 327 |
break
|
| 328 |
nb_iters = nb_epochs * len(train_data_loader)
|
| 329 |
-
checkpoint_interval = int(
|
| 330 |
-
decay_lr =
|
| 331 |
-
lr_partitions_ratios =
|
| 332 |
lr_partitions_iters = [int(r * nb_iters) for r in lr_partitions_ratios[:-1]]
|
| 333 |
lr_partitions_iters = lr_partitions_iters + [nb_iters - sum(lr_partitions_iters)]
|
| 334 |
assert sum(lr_partitions_iters) == nb_iters
|
| 335 |
-
init_lr =
|
| 336 |
-
max_lr =
|
| 337 |
-
min_lr =
|
| 338 |
-
lr_warmup_iters_ratio =
|
| 339 |
-
lr_decay_iters_ratio =
|
| 340 |
-
beta1 =
|
| 341 |
-
beta2 =
|
| 342 |
-
weight_decay =
|
| 343 |
-
grad_clip =
|
| 344 |
-
compile =
|
| 345 |
-
compile_mode =
|
| 346 |
-
save_only_last_checkpoint =
|
| 347 |
|
| 348 |
|
| 349 |
def human_readable(num):
|
|
@@ -389,7 +385,7 @@ if __name__ == "__main__":
|
|
| 389 |
|
| 390 |
|
| 391 |
# log parameters
|
| 392 |
-
loguru.logger.info(f"data_dir: {
|
| 393 |
loguru.logger.info(f"block_size: {block_size}")
|
| 394 |
loguru.logger.info(f"vocab_size: {vocab_size}")
|
| 395 |
loguru.logger.info(f"n_embd: {n_embd}")
|
|
@@ -421,9 +417,8 @@ if __name__ == "__main__":
|
|
| 421 |
loguru.logger.info(f"save_only_last_checkpoint: {save_only_last_checkpoint}")
|
| 422 |
|
| 423 |
# save parameters into a train_parameters.json
|
| 424 |
-
import json
|
| 425 |
train_params = {
|
| 426 |
-
"data_dir":
|
| 427 |
"block_size": block_size,
|
| 428 |
"vocab_size": vocab_size,
|
| 429 |
"n_embd": n_embd,
|
|
@@ -454,11 +449,11 @@ if __name__ == "__main__":
|
|
| 454 |
"intermediate_schedules": intermediate_schedules,
|
| 455 |
"save_only_last_checkpoint": save_only_last_checkpoint,
|
| 456 |
}
|
| 457 |
-
with open(os.path.join(
|
| 458 |
|
| 459 |
# load the last checkpoint if it exists, otherwise initialize the training from scratch
|
| 460 |
try:
|
| 461 |
-
last_checkpoint = torch.load(os.path.join(
|
| 462 |
start_epoch = last_checkpoint["epoch"]
|
| 463 |
start_epoch_iter = last_checkpoint["epoch_iter"] + 1
|
| 464 |
model_state_dict = last_checkpoint["model_state_dict"]
|
|
@@ -467,7 +462,7 @@ if __name__ == "__main__":
|
|
| 467 |
patience_counter = last_checkpoint["patience_counter"]
|
| 468 |
improved_this_epoch = last_checkpoint["improved_this_epoch"]
|
| 469 |
except FileNotFoundError:
|
| 470 |
-
os.makedirs(os.path.join(
|
| 471 |
start_epoch = 0
|
| 472 |
start_epoch_iter = 0
|
| 473 |
model_state_dict = None
|
|
@@ -524,10 +519,10 @@ if __name__ == "__main__":
|
|
| 524 |
torch.set_float32_matmul_precision("high")
|
| 525 |
|
| 526 |
# initialize the np memmap array to save the batch losses
|
| 527 |
-
batch_losses_path = os.path.join(
|
| 528 |
-
last_batch_loss_idx_path = os.path.join(
|
| 529 |
-
val_losses_path = os.path.join(
|
| 530 |
-
last_val_loss_idx_path = os.path.join(
|
| 531 |
|
| 532 |
try:
|
| 533 |
batch_losses = np.lib.format.open_memmap(batch_losses_path, mode="r+", dtype=np.float32, shape=(nb_iters,))
|
|
@@ -546,7 +541,7 @@ if __name__ == "__main__":
|
|
| 546 |
|
| 547 |
# create data_loader for validation
|
| 548 |
val_data_loader = torch.utils.data.DataLoader(
|
| 549 |
-
FlowshopDataset(
|
| 550 |
batch_size=val_batch_size,
|
| 551 |
shuffle=False,
|
| 552 |
)
|
|
@@ -560,7 +555,7 @@ if __name__ == "__main__":
|
|
| 560 |
# implement the logic to resume after failure
|
| 561 |
## create the generator, sampler, data loader
|
| 562 |
generator = torch.Generator()
|
| 563 |
-
generator.manual_seed(
|
| 564 |
train_sampler = torch.utils.data.RandomSampler(
|
| 565 |
train_dataset,
|
| 566 |
generator=generator
|
|
@@ -586,7 +581,7 @@ if __name__ == "__main__":
|
|
| 586 |
initial=start_epoch_iter,
|
| 587 |
desc=f"Epoch {epoch+1}/{nb_epochs}",
|
| 588 |
)):
|
| 589 |
-
with open(os.path.join(
|
| 590 |
|
| 591 |
# move the batch to the device
|
| 592 |
schedules_batch = schedules_batch.to(device)
|
|
@@ -629,7 +624,7 @@ if __name__ == "__main__":
|
|
| 629 |
val_data_loader,
|
| 630 |
desc=f"Validation {epoch+(epoch_iter+1)/len(train_data_loader):.2f}",
|
| 631 |
)):
|
| 632 |
-
with open(os.path.join(
|
| 633 |
|
| 634 |
# move the batch to the device
|
| 635 |
schedules_batch = schedules_batch.to(device)
|
|
@@ -640,7 +635,7 @@ if __name__ == "__main__":
|
|
| 640 |
makespans, loss = train_model(schedules_batch, makespans_batch)
|
| 641 |
total_val_loss += loss.item() * schedules_batch.size(0)
|
| 642 |
# ======
|
| 643 |
-
with open(os.path.join(
|
| 644 |
|
| 645 |
# compute the total validation loss (averaging over the dataset)
|
| 646 |
total_val_loss /= len(val_data_loader.dataset)
|
|
@@ -669,22 +664,22 @@ if __name__ == "__main__":
|
|
| 669 |
}
|
| 670 |
torch.save(
|
| 671 |
checkpoint,
|
| 672 |
-
os.path.join(
|
| 673 |
)
|
| 674 |
if not save_only_last_checkpoint:
|
| 675 |
torch.save(
|
| 676 |
checkpoint,
|
| 677 |
-
os.path.join(
|
| 678 |
)
|
| 679 |
if best_val_loss == total_val_loss:
|
| 680 |
torch.save(
|
| 681 |
checkpoint,
|
| 682 |
-
os.path.join(
|
| 683 |
)
|
| 684 |
# ======
|
| 685 |
# ======
|
| 686 |
# ======
|
| 687 |
-
with open(os.path.join(
|
| 688 |
|
| 689 |
# set the start_epoch_iter to 0 for the next epoch
|
| 690 |
start_epoch_iter = 0
|
|
@@ -703,7 +698,77 @@ if __name__ == "__main__":
|
|
| 703 |
loguru.logger.info(f"Best validation loss: {best_val_loss:.4f}")
|
| 704 |
|
| 705 |
# create experiment termination flag file
|
| 706 |
-
with open(os.path.join(
|
| 707 |
pass
|
| 708 |
# ======
|
| 709 |
-
# ======
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
# ======
|
| 174 |
|
| 175 |
|
| 176 |
+
def train(
|
| 177 |
+
testing: bool,
|
| 178 |
+
seed: int,
|
| 179 |
+
data_dir: str,
|
| 180 |
+
n_embd: int,
|
| 181 |
+
n_head: int,
|
| 182 |
+
n_layer: int,
|
| 183 |
+
intermediate_schedules: bool,
|
| 184 |
+
dropout: float,
|
| 185 |
+
ff_width: int,
|
| 186 |
+
train_batch_size: int,
|
| 187 |
+
val_batch_size: int,
|
| 188 |
+
nb_epochs: int,
|
| 189 |
+
early_stopping_patience: int,
|
| 190 |
+
checkpoint_interval_ratio: float,
|
| 191 |
+
decay_lr: bool,
|
| 192 |
+
lr_partitions_ratios: list[float],
|
| 193 |
+
init_lr: float,
|
| 194 |
+
max_lr: float,
|
| 195 |
+
min_lr: float,
|
| 196 |
+
lr_warmup_iters_ratio: float,
|
| 197 |
+
lr_decay_iters_ratio: float,
|
| 198 |
+
beta1: float,
|
| 199 |
+
beta2: float,
|
| 200 |
+
weight_decay: float,
|
| 201 |
+
grad_clip: float,
|
| 202 |
+
compile: bool,
|
| 203 |
+
compile_mode: str,
|
| 204 |
+
save_only_last_checkpoint: bool,
|
| 205 |
+
output_dir: str,
|
| 206 |
+
):
|
| 207 |
+
|
| 208 |
+
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
# check if experiment termination flag file exists
|
| 211 |
+
if not testing:
|
| 212 |
+
if os.path.exists(os.path.join(output_dir, ".terminated_phase1")):
|
| 213 |
print("Phase 1 already terminated. Exiting...")
|
| 214 |
+
return
|
| 215 |
# ======
|
| 216 |
+
if not os.path.exists(os.path.join(output_dir, "viz_train.ipynb")):
|
| 217 |
+
shutil.copy("viz_train.ipynb", output_dir)
|
| 218 |
# ======
|
| 219 |
else:
|
| 220 |
|
|
|
|
| 232 |
"viz_train.ipynb",
|
| 233 |
]
|
| 234 |
for f in files_to_delete:
|
| 235 |
+
f_path = os.path.join(output_dir, f)
|
| 236 |
if os.path.exists(f_path): os.remove(f_path)
|
| 237 |
# ======
|
| 238 |
+
checkpoints_dir = os.path.join(output_dir, "checkpoints")
|
| 239 |
if os.path.exists(checkpoints_dir): shutil.rmtree(checkpoints_dir)
|
| 240 |
+
shutil.copy("viz_train.ipynb", output_dir)
|
| 241 |
# ======
|
| 242 |
|
| 243 |
# check if GPU is available
|
|
|
|
| 245 |
device = "cuda"
|
| 246 |
|
| 247 |
# setup logging
|
| 248 |
+
loguru.logger.add(os.path.join(output_dir, "train.log"))
|
| 249 |
|
| 250 |
# set random seeds
|
| 251 |
+
torch.manual_seed(seed)
|
| 252 |
+
random.seed(seed)
|
| 253 |
+
np.random.seed(seed)
|
| 254 |
|
| 255 |
# setup model architecture parameters
|
| 256 |
+
with open(os.path.join(data_dir, "metadata.json"), "r") as f:
|
| 257 |
metadata = json.load(f)
|
| 258 |
block_size = metadata["nb_jobs"] # context window size
|
| 259 |
vocab_size = metadata["nb_jobs"] # vocabulary size
|
| 260 |
+
n_embd = n_embd # embedding dimension
|
| 261 |
+
n_head = n_head # number of attention heads
|
| 262 |
assert n_embd % n_head == 0
|
| 263 |
+
n_layer = n_layer # number of transformer blocks
|
| 264 |
+
intermediate_schedules = intermediate_schedules
|
| 265 |
+
ff_width = ff_width
|
| 266 |
|
| 267 |
# setup training parameters and utils
|
| 268 |
+
train_batch_size = train_batch_size # batch size for training
|
| 269 |
+
val_batch_size = val_batch_size # batch size for validation
|
| 270 |
+
nb_epochs = nb_epochs # number of pseudo-epochs to train for
|
| 271 |
+
early_stopping_patience = early_stopping_patience # number of epochs without improvement to trigger early stopping
|
| 272 |
+
dropout = dropout
|
| 273 |
|
| 274 |
|
| 275 |
class FlowshopDataset(torch.utils.data.Dataset):
|
|
|
|
| 309 |
# ======
|
| 310 |
|
| 311 |
|
| 312 |
+
train_dataset = FlowshopDataset(data_dir, split="train", load_in_memory=False)
|
| 313 |
train_data_loader = torch.utils.data.DataLoader(
|
| 314 |
train_dataset,
|
| 315 |
batch_size=train_batch_size,
|
|
|
|
| 322 |
loguru.logger.info(f"makespans.shape: {makespans.shape}")
|
| 323 |
break
|
| 324 |
nb_iters = nb_epochs * len(train_data_loader)
|
| 325 |
+
checkpoint_interval = int(checkpoint_interval_ratio * len(train_data_loader))
|
| 326 |
+
decay_lr = decay_lr
|
| 327 |
+
lr_partitions_ratios = lr_partitions_ratios + [None]
|
| 328 |
lr_partitions_iters = [int(r * nb_iters) for r in lr_partitions_ratios[:-1]]
|
| 329 |
lr_partitions_iters = lr_partitions_iters + [nb_iters - sum(lr_partitions_iters)]
|
| 330 |
assert sum(lr_partitions_iters) == nb_iters
|
| 331 |
+
init_lr = init_lr #1e-4
|
| 332 |
+
max_lr = max_lr #1e-3
|
| 333 |
+
min_lr = min_lr #5*1e-5
|
| 334 |
+
lr_warmup_iters_ratio = lr_warmup_iters_ratio #0.1
|
| 335 |
+
lr_decay_iters_ratio = lr_decay_iters_ratio #0.95
|
| 336 |
+
beta1 = beta1 # Adam beta1
|
| 337 |
+
beta2 = beta2 # Adam beta2
|
| 338 |
+
weight_decay = weight_decay # 1e-1 # weight decay
|
| 339 |
+
grad_clip = grad_clip # 1.0 # gradient clipping value
|
| 340 |
+
compile = compile
|
| 341 |
+
compile_mode = compile_mode
|
| 342 |
+
save_only_last_checkpoint = save_only_last_checkpoint
|
| 343 |
|
| 344 |
|
| 345 |
def human_readable(num):
|
|
|
|
| 385 |
|
| 386 |
|
| 387 |
# log parameters
|
| 388 |
+
loguru.logger.info(f"data_dir: {data_dir}")
|
| 389 |
loguru.logger.info(f"block_size: {block_size}")
|
| 390 |
loguru.logger.info(f"vocab_size: {vocab_size}")
|
| 391 |
loguru.logger.info(f"n_embd: {n_embd}")
|
|
|
|
| 417 |
loguru.logger.info(f"save_only_last_checkpoint: {save_only_last_checkpoint}")
|
| 418 |
|
| 419 |
# save parameters into a train_parameters.json
|
|
|
|
| 420 |
train_params = {
|
| 421 |
+
"data_dir": data_dir,
|
| 422 |
"block_size": block_size,
|
| 423 |
"vocab_size": vocab_size,
|
| 424 |
"n_embd": n_embd,
|
|
|
|
| 449 |
"intermediate_schedules": intermediate_schedules,
|
| 450 |
"save_only_last_checkpoint": save_only_last_checkpoint,
|
| 451 |
}
|
| 452 |
+
with open(os.path.join(output_dir, "train_parameters.json"), "w") as f: json.dump(train_params, f, indent=4)
|
| 453 |
|
| 454 |
# load the last checkpoint if it exists, otherwise initialize the training from scratch
|
| 455 |
try:
|
| 456 |
+
last_checkpoint = torch.load(os.path.join(output_dir, "checkpoints", "last_checkpoint.pth"))
|
| 457 |
start_epoch = last_checkpoint["epoch"]
|
| 458 |
start_epoch_iter = last_checkpoint["epoch_iter"] + 1
|
| 459 |
model_state_dict = last_checkpoint["model_state_dict"]
|
|
|
|
| 462 |
patience_counter = last_checkpoint["patience_counter"]
|
| 463 |
improved_this_epoch = last_checkpoint["improved_this_epoch"]
|
| 464 |
except FileNotFoundError:
|
| 465 |
+
os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True)
|
| 466 |
start_epoch = 0
|
| 467 |
start_epoch_iter = 0
|
| 468 |
model_state_dict = None
|
|
|
|
| 519 |
torch.set_float32_matmul_precision("high")
|
| 520 |
|
| 521 |
# initialize the np memmap array to save the batch losses
|
| 522 |
+
batch_losses_path = os.path.join(output_dir, "batch_losses.npy")
|
| 523 |
+
last_batch_loss_idx_path = os.path.join(output_dir, "last_batch_loss_idx.npy")
|
| 524 |
+
val_losses_path = os.path.join(output_dir, "val_losses.npy")
|
| 525 |
+
last_val_loss_idx_path = os.path.join(output_dir, "last_val_loss_idx.npy")
|
| 526 |
|
| 527 |
try:
|
| 528 |
batch_losses = np.lib.format.open_memmap(batch_losses_path, mode="r+", dtype=np.float32, shape=(nb_iters,))
|
|
|
|
| 541 |
|
| 542 |
# create data_loader for validation
|
| 543 |
val_data_loader = torch.utils.data.DataLoader(
|
| 544 |
+
FlowshopDataset(data_dir, split="val", load_in_memory=True),
|
| 545 |
batch_size=val_batch_size,
|
| 546 |
shuffle=False,
|
| 547 |
)
|
|
|
|
| 555 |
# implement the logic to resume after failure
|
| 556 |
## create the generator, sampler, data loader
|
| 557 |
generator = torch.Generator()
|
| 558 |
+
generator.manual_seed(seed + epoch)
|
| 559 |
train_sampler = torch.utils.data.RandomSampler(
|
| 560 |
train_dataset,
|
| 561 |
generator=generator
|
|
|
|
| 581 |
initial=start_epoch_iter,
|
| 582 |
desc=f"Epoch {epoch+1}/{nb_epochs}",
|
| 583 |
)):
|
| 584 |
+
with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
|
| 585 |
|
| 586 |
# move the batch to the device
|
| 587 |
schedules_batch = schedules_batch.to(device)
|
|
|
|
| 624 |
val_data_loader,
|
| 625 |
desc=f"Validation {epoch+(epoch_iter+1)/len(train_data_loader):.2f}",
|
| 626 |
)):
|
| 627 |
+
with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
|
| 628 |
|
| 629 |
# move the batch to the device
|
| 630 |
schedules_batch = schedules_batch.to(device)
|
|
|
|
| 635 |
makespans, loss = train_model(schedules_batch, makespans_batch)
|
| 636 |
total_val_loss += loss.item() * schedules_batch.size(0)
|
| 637 |
# ======
|
| 638 |
+
with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
|
| 639 |
|
| 640 |
# compute the total validation loss (averaging over the dataset)
|
| 641 |
total_val_loss /= len(val_data_loader.dataset)
|
|
|
|
| 664 |
}
|
| 665 |
torch.save(
|
| 666 |
checkpoint,
|
| 667 |
+
os.path.join(output_dir, "checkpoints", "last_checkpoint.pth")
|
| 668 |
)
|
| 669 |
if not save_only_last_checkpoint:
|
| 670 |
torch.save(
|
| 671 |
checkpoint,
|
| 672 |
+
os.path.join(output_dir, "checkpoints", f"checkpoint_epoch_{epoch+(epoch_iter+1)/len(train_data_loader):.2f}.pth")
|
| 673 |
)
|
| 674 |
if best_val_loss == total_val_loss:
|
| 675 |
torch.save(
|
| 676 |
checkpoint,
|
| 677 |
+
os.path.join(output_dir, "checkpoints", "best_checkpoint.pth")
|
| 678 |
)
|
| 679 |
# ======
|
| 680 |
# ======
|
| 681 |
# ======
|
| 682 |
+
with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
|
| 683 |
|
| 684 |
# set the start_epoch_iter to 0 for the next epoch
|
| 685 |
start_epoch_iter = 0
|
|
|
|
| 698 |
loguru.logger.info(f"Best validation loss: {best_val_loss:.4f}")
|
| 699 |
|
| 700 |
# create experiment termination flag file
|
| 701 |
+
with open(os.path.join(output_dir, ".terminated_phase1"), "w") as f:
|
| 702 |
pass
|
| 703 |
# ======
|
| 704 |
+
# ======
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def parse_bool_arg(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty token (including ``"False"`` and ``"0"``) becomes
    ``True``. This converter interprets the usual textual spellings instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off and
        for the empty string (case-insensitive, surrounding whitespace
        ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    import argparse

    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as "false" because it was the only way to pass
    # a false value while the flags were declared with ``type=bool``.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


if __name__ == "__main__":

    # parse arguments
    from argparse import ArgumentParser
    parser = ArgumentParser()
    # Boolean flags use parse_bool_arg rather than ``type=bool`` so that
    # "--flag False" really yields False.
    parser.add_argument("--testing", type=parse_bool_arg, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--n_embd", type=int, required=True)
    parser.add_argument("--n_head", type=int, required=True)
    parser.add_argument("--n_layer", type=int, required=True)
    parser.add_argument("--intermediate_schedules", type=parse_bool_arg, required=True)
    parser.add_argument("--dropout", type=float, required=True)
    parser.add_argument("--ff_width", type=int, required=True)
    parser.add_argument("--train_batch_size", type=int, required=True)
    parser.add_argument("--val_batch_size", type=int, required=True)
    parser.add_argument("--nb_epochs", type=int, required=True)
    parser.add_argument("--early_stopping_patience", type=int, required=True)
    parser.add_argument("--checkpoint_interval_ratio", type=float, required=True)
    parser.add_argument("--decay_lr", type=parse_bool_arg, required=True)
    # NOTE(review): train() converts every ratio to int(r * nb_iters) and gives
    # the remainder to an extra final partition, so the example in the help
    # text (0.1,0.5,1) would leave a negative remainder -- confirm the intended
    # usage before relying on that example.
    parser.add_argument("--lr_partitions_ratios", type=lambda s: [float(item) for item in s.split(',')], help='Comma-separated list of floats that do not add up to 1 (e.g., 0.1,0.5,1)', required=True)
    parser.add_argument("--init_lr", type=float, required=True)
    parser.add_argument("--max_lr", type=float, required=True)
    parser.add_argument("--min_lr", type=float, required=True)
    parser.add_argument("--lr_warmup_iters_ratio", type=float, required=True)
    parser.add_argument("--lr_decay_iters_ratio", type=float, required=True)
    parser.add_argument("--beta1", type=float, required=True)
    parser.add_argument("--beta2", type=float, required=True)
    parser.add_argument("--weight_decay", type=float, required=True)
    parser.add_argument("--grad_clip", type=float, required=True)
    parser.add_argument("--compile", type=parse_bool_arg, required=True)
    parser.add_argument("--compile_mode", type=str, required=True)
    parser.add_argument("--save_only_last_checkpoint", type=parse_bool_arg, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    args = parser.parse_args()

    # forward every parsed option to train() by keyword
    train(
        testing=args.testing,
        seed=args.seed,
        data_dir=args.data_dir,
        n_embd=args.n_embd,
        n_head=args.n_head,
        n_layer=args.n_layer,
        intermediate_schedules=args.intermediate_schedules,
        dropout=args.dropout,
        ff_width=args.ff_width,
        train_batch_size=args.train_batch_size,
        val_batch_size=args.val_batch_size,
        nb_epochs=args.nb_epochs,
        early_stopping_patience=args.early_stopping_patience,
        checkpoint_interval_ratio=args.checkpoint_interval_ratio,
        decay_lr=args.decay_lr,
        lr_partitions_ratios=args.lr_partitions_ratios,
        init_lr=args.init_lr,
        max_lr=args.max_lr,
        min_lr=args.min_lr,
        lr_warmup_iters_ratio=args.lr_warmup_iters_ratio,
        lr_decay_iters_ratio=args.lr_decay_iters_ratio,
        beta1=args.beta1,
        beta2=args.beta2,
        weight_decay=args.weight_decay,
        grad_clip=args.grad_clip,
        compile=args.compile,
        compile_mode=args.compile_mode,
        save_only_last_checkpoint=args.save_only_last_checkpoint,
        output_dir=args.output_dir,
    )
|