younadi committed
Commit 6179bb6 · 1 Parent(s): ff1ceb8
source/.ipynb CHANGED
@@ -3,6 +3,28 @@
   {
    "cell_type": "code",
    "execution_count": 2,
+   "id": "6ff2d58c",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "False"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import os\n",
+    "os.path.exists(\"\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "id": "ee46ab38",
    "metadata": {},
    "outputs": [
source/create_dataset.bash CHANGED
@@ -1,5 +1,5 @@
 python create_dataset.py\
-    --testing True\
+    --testing False\
     --nb_jobs 7\
     --nb_machines 2\
     --time_min 0\
@@ -8,4 +8,6 @@ python create_dataset.py\
     --init_type exhaustive\
     --output_dir "./demos/ftd"\
     --seed 97\
-    --normalize_makespans "true"\
+    --normalize_makespans "true"\
+    --pfsp_instance ""\
+    --autoname_output_dir ""\
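A note on the flag style above: argparse's `type=bool` simply calls `bool()` on the raw argument string, so any non-empty value, including the literal `False`, parses as `True`; only the empty string parses as `False`. That is presumably why the new flags are passed as `""`, and it also means `--testing False` still yields a truthy `testing` wherever the flag is declared with `type=bool`, as the parsers in this commit do. A minimal sketch of the pitfall (standalone, not part of the commit):

```python
# Sketch of the argparse type=bool pitfall; the flag name mirrors this script.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--testing", type=bool, required=True)

print(parser.parse_args(["--testing", "False"]).testing)  # True: bool("False") is truthy
print(parser.parse_args(["--testing", ""]).testing)       # False: bool("") is falsy
```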
source/create_dataset.py CHANGED
@@ -6,6 +6,7 @@ import math
 from loguru import logger
 import tqdm
 import time
+import namer
 
 
 def generate_random_pfsp_instance(nb_jobs, nb_machines, time_min, time_max, seed=97):
@@ -389,19 +390,36 @@
     output_dir,
     seed,
     normalize_makespans,
+    nb_jobs,
+    nb_machines,
+    time_min,
+    time_max,
+    autoname_output_dir,
 ):
 
+    if autoname_output_dir:
+        output_dir = os.path.join(output_dir, time.strftime("%Y_%m_%d_%H_%M_%S") + "_" + namer.generate(separator="_", category="sports"))
+
+    # prepare logging
+    logger.add(os.path.join(output_dir, "create_dataset.log"))
+
+    if os.path.exists(pfsp_instance):
+        # TODO: add logic to load pfsp_instance from some file
+        pass
+    else:
+        # create pfsp_instance
+        logger.info(f"Creating pfsp_instance with {nb_jobs} jobs and {nb_machines} machines")
+        pfsp_instance = generate_random_pfsp_instance(nb_jobs, nb_machines, time_min, time_max, seed=seed)
+
     # create the output folder
     os.makedirs(output_dir, exist_ok=True)
 
     # check if experiment termination flag file exists
     if not testing:
-        if os.path.exists(os.path.join(args.output_dir, ".terminated_create_dataset")):
+        if os.path.exists(os.path.join(output_dir, ".terminated_create_dataset")):
             print("Dataset creation already done. Exiting...")
             return None
 
-    # prepare loging
-    logger.add(os.path.join(output_dir, "create_dataset.log"))
 
     # log parameters
     logger.info(f"nb_samples: {nb_samples}")
@@ -561,21 +579,25 @@ if __name__ == "__main__":
     parser.add_argument("--nb_samples", type=int, required=True, help="Number of base samples to generate")
     parser.add_argument("--init_type", type=str, required=True, choices=["exhaustive", "cds", "palmer", "neh", "heuristics", "random"], help="Initialization type for the base samples")
     parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory where dataset artifacts will be saved")
+    parser.add_argument("--autoname_output_dir", type=bool, required=True, help="Whether to autoname the output directory")
     parser.add_argument("--seed", type=int, required=True, help="Random seed for reproducibility (set to None for no seeding)")
     parser.add_argument("--normalize_makespans", type=bool, required=True, help="Whether to normalize makespans by the sum of processing times")
+    parser.add_argument("--pfsp_instance", type=str, required=True, help="Path to the pfsp instance or None if to be generated")
     args = parser.parse_args()
-
-    # create pfsp_instance
-    pfsp_instance = generate_random_pfsp_instance(args.nb_jobs, args.nb_machines, args.time_min, args.time_max, seed=args.seed)
 
     # create the dataset
     create_dataset(
         testing=args.testing,
-        pfsp_instance=pfsp_instance,
         nb_samples=args.nb_samples,
         init_type=args.init_type,
        output_dir=args.output_dir,
        seed=args.seed,
        normalize_makespans=args.normalize_makespans,
+        pfsp_instance=args.pfsp_instance,
+        nb_jobs=args.nb_jobs,
+        nb_machines=args.nb_machines,
+        time_min=args.time_min,
+        time_max=args.time_max,
+        autoname_output_dir=args.autoname_output_dir,
    )
    # ======
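The empty string also serves as a sentinel for `--pfsp_instance`: `os.path.exists("")` is always `False` (the very check run in the notebook cell above), so an empty path falls through to `generate_random_pfsp_instance`, while an existing file would take the load branch that is still a TODO. A small sketch of that dispatch; the instance path is hypothetical:

```python
# Sketch of the pfsp_instance dispatch added in this commit; the file path is hypothetical.
import os

for pfsp_instance in ["", "./instances/example_instance.txt"]:
    if os.path.exists(pfsp_instance):
        print(f"{pfsp_instance!r}: load the instance from file (still a TODO above)")
    else:
        print(f"{pfsp_instance!r}: generate a random instance instead")
```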
source/demos/ftd/create_dataset.log CHANGED
@@ -5,3 +5,19 @@
 2026-03-10 13:57:19.791 | INFO | __main__:create_dataset:415 - Normalizing makespans by the sum of processing times with pfsp sum: 7.829251766204834
 2026-03-10 13:57:19.791 | INFO | __main__:create_dataset:422 - Exhaustive init_type: Number of samples: 5040
 2026-03-10 13:57:19.834 | INFO | __main__:create_dataset:533 - Minimum makespan: 0.5417570471763611
+2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:418 - nb_samples: 0
+2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:419 - init_type: exhaustive
+2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:420 - output_dir: ./demos/ftd
+2026-03-12 11:17:09.860 | INFO | __main__:create_dataset:421 - seed: 97
+2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:420 - nb_samples: 0
+2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:421 - init_type: exhaustive
+2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:422 - output_dir: ./demos/ftd
+2026-03-12 11:19:00.900 | INFO | __main__:create_dataset:423 - seed: 97
+2026-03-12 11:20:45.676 | INFO | __main__:create_dataset:405 - Creating pfsp_instance with 7 jobs and 2 machines
+2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:419 - nb_samples: 0
+2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:420 - init_type: exhaustive
+2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:421 - output_dir: ./demos/ftd
+2026-03-12 11:20:45.703 | INFO | __main__:create_dataset:422 - seed: 97
+2026-03-12 11:20:45.704 | INFO | __main__:create_dataset:427 - Normalizing makespans by the sum of processing times with pfsp sum: 7.829251766204834
+2026-03-12 11:20:45.704 | INFO | __main__:create_dataset:434 - Exhaustive init_type: Number of samples: 5040
+2026-03-12 11:20:45.751 | INFO | __main__:create_dataset:545 - Minimum makespan: 0.5417570471763611
source/demos/ftd/metadata.json CHANGED
@@ -5,5 +5,5 @@
     "init_type": "exhaustive",
     "data_path": "./demos/ftd",
     "seed": 97,
-    "date_time": "2026_03_10_13_57_19"
+    "date_time": "2026_03_12_11_20_45"
 }
source/demos/rs_artifacts/recover_schedules.log CHANGED
@@ -1,19 +1,19 @@
-2026-03-10 16:36:22.938 | INFO | __main__:<module>:395 - Found better makespan!:
-2026-03-10 16:36:22.939 | INFO | __main__:<module>:396 - recovered permutation: [1 4 3 2 6 5 0]
-2026-03-10 16:36:22.939 | INFO | __main__:<module>:397 - actual makespan: 5.0353
-2026-03-10 16:36:22.939 | INFO | __main__:<module>:398 - actual makespan normalized: 0.6431
-2026-03-10 16:36:22.939 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.6463
-2026-03-10 16:36:23.148 | INFO | __main__:<module>:395 - Found better makespan!:
-2026-03-10 16:36:23.149 | INFO | __main__:<module>:396 - recovered permutation: [6 1 3 4 0 5 2]
-2026-03-10 16:36:23.149 | INFO | __main__:<module>:397 - actual makespan: 4.6095
-2026-03-10 16:36:23.149 | INFO | __main__:<module>:398 - actual makespan normalized: 0.5888
-2026-03-10 16:36:23.149 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.5881
-2026-03-10 16:36:35.134 | INFO | __main__:<module>:395 - Found better makespan!:
-2026-03-10 16:36:35.134 | INFO | __main__:<module>:396 - recovered permutation: [4 6 3 1 0 5 2]
-2026-03-10 16:36:35.134 | INFO | __main__:<module>:397 - actual makespan: 4.3768
-2026-03-10 16:36:35.134 | INFO | __main__:<module>:398 - actual makespan normalized: 0.5590
-2026-03-10 16:36:35.134 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.5755
-2026-03-10 16:37:03.649 | INFO | __main__:<module>:411 - NEH makespan: 0.5418
-2026-03-10 16:37:03.649 | INFO | __main__:<module>:412 - CDS makespan: 0.5418
-2026-03-10 16:37:03.649 | INFO | __main__:<module>:413 - Palmer makespan: 0.5595
-2026-03-10 16:37:03.649 | INFO | __main__:<module>:414 - Best makespan found by optimization: 0.5590
+2026-03-12 13:06:43.046 | INFO | __main__:<module>:395 - Found better makespan!:
+2026-03-12 13:06:43.046 | INFO | __main__:<module>:396 - recovered permutation: [1 4 3 2 6 5 0]
+2026-03-12 13:06:43.046 | INFO | __main__:<module>:397 - actual makespan: 5.0353
+2026-03-12 13:06:43.047 | INFO | __main__:<module>:398 - actual makespan normalized: 0.6431
+2026-03-12 13:06:43.047 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.6463
+2026-03-12 13:06:43.257 | INFO | __main__:<module>:395 - Found better makespan!:
+2026-03-12 13:06:43.257 | INFO | __main__:<module>:396 - recovered permutation: [6 1 3 4 0 5 2]
+2026-03-12 13:06:43.258 | INFO | __main__:<module>:397 - actual makespan: 4.6095
+2026-03-12 13:06:43.258 | INFO | __main__:<module>:398 - actual makespan normalized: 0.5888
+2026-03-12 13:06:43.258 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.5881
+2026-03-12 13:06:55.273 | INFO | __main__:<module>:395 - Found better makespan!:
+2026-03-12 13:06:55.273 | INFO | __main__:<module>:396 - recovered permutation: [4 6 3 1 0 5 2]
+2026-03-12 13:06:55.273 | INFO | __main__:<module>:397 - actual makespan: 4.3768
+2026-03-12 13:06:55.273 | INFO | __main__:<module>:398 - actual makespan normalized: 0.5590
+2026-03-12 13:06:55.273 | INFO | __main__:<module>:399 - predicted makespan (normalized): 0.5755
+2026-03-12 13:07:23.978 | INFO | __main__:<module>:411 - NEH makespan: 0.5418
+2026-03-12 13:07:23.978 | INFO | __main__:<module>:412 - CDS makespan: 0.5418
+2026-03-12 13:07:23.978 | INFO | __main__:<module>:413 - Palmer makespan: 0.5595
+2026-03-12 13:07:23.978 | INFO | __main__:<module>:414 - Best makespan found by optimization: 0.5590
source/demos/rs_artifacts/recover_schedules_pbar.log CHANGED
@@ -1 +1 @@
-Latent schedules optimization: 100%|██████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:41<00:00, 48.72it/s, total loss=1.18e+3, makespan=-0.0432, sinkhorn=592]
+Latent schedules optimization: 100%|████████████████████████| 2000/2000 [00:41<00:00, 48.44it/s, total loss=1.18e+3, makespan=-0.0432, sinkhorn=592]
source/demos/train_artifacts/checkpoints/best_checkpoint.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4b1dad33e8a604d33778a2b09bfc3bb7ac8963a3ebb922bcc3e54a394bb6fe27
+oid sha256:a7f2a830b2b04168daacedd0b0c41ed034c98ec3be35611dd2a32c8cf6602353
 size 1272039
source/demos/train_artifacts/checkpoints/last_checkpoint.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8eedb1539052d19084f49a89460d478455c41828f7cc6711f0c12358a8e460d6
+oid sha256:6500dac2061304bfd78de1d483743f4e57cc9830e68f30e148c9c9df5d483f62
 size 1272039
source/demos/train_artifacts/train.log CHANGED
@@ -1,41 +1,41 @@
-2026-03-10 16:18:33.920 | INFO | __main__:__init__:292 - Loaded schedules from ./demos/ftd_processed/schedules_train.npy with shape (8568, 7)
-2026-03-10 16:18:33.920 | INFO | __main__:__init__:297 - Loaded makespans from ./demos/ftd_processed/makespans_train.npy with shape (8568, 7)
-2026-03-10 16:18:33.921 | INFO | __main__:<module>:325 - schedules.shape: torch.Size([64, 7])
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:326 - makespans.shape: torch.Size([64, 7])
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:392 - data_dir: ./demos/ftd_processed
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:393 - block_size: 7
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:394 - vocab_size: 7
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:395 - n_embd: 64
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:396 - n_head: 4
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:397 - n_layer: 2
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:398 - ff_width: 4
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:399 - train_batch_size: 64
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:400 - val_batch_size: 256
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:401 - dropout: 0.0
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:402 - nb_epochs: 5
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:403 - early_stopping_patience: 15
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:404 - nb_iters: 670
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:405 - checkpoint_interval: 33
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:406 - decay_lr: True
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:407 - lr_partitions_ratios: [0.66, None]
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:408 - lr_partitions_iters: [442, 228]
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:409 - init_lr: 0.0001
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:410 - max_lr: 0.001
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:411 - min_lr: 5e-05
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:412 - lr_warmup_iters_ratio: 0.1
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:413 - lr_decay_iters_ratio: 0.95
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:414 - beta1: 0.9
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:415 - beta2: 0.95
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:416 - weight_decay: 0.1
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:417 - grad_clip: 1.0
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:418 - compile: False
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:419 - compile_mode: default
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:420 - intermediate_schedules: True
-2026-03-10 16:18:33.922 | INFO | __main__:<module>:421 - save_only_last_checkpoint: True
-2026-03-10 16:18:34.011 | INFO | __main__:<module>:490 - The model has 100K trainable parameters
-2026-03-10 16:18:34.011 | INFO | __main__:<module>:511 - num decayed parameter tensors: 40, with 99,712 parameters
-2026-03-10 16:18:34.011 | INFO | __main__:<module>:512 - num non-decayed parameter tensors: 9, with 513 parameters
-2026-03-10 16:18:34.011 | INFO | __main__:<module>:516 - using fused AdamW: True
-2026-03-10 16:18:34.544 | INFO | __main__:__init__:292 - Loaded schedules from ./demos/ftd_processed/schedules_val.npy with shape (756, 7)
-2026-03-10 16:18:34.544 | INFO | __main__:__init__:297 - Loaded makespans from ./demos/ftd_processed/makespans_val.npy with shape (756, 7)
-2026-03-10 16:18:42.454 | INFO | __main__:<module>:703 - Best validation loss: 0.0003
+2026-03-12 13:06:19.589 | INFO | __main__:__init__:288 - Loaded schedules from ./demos/ftd_processed/schedules_train.npy with shape (8568, 7)
+2026-03-12 13:06:19.590 | INFO | __main__:__init__:293 - Loaded makespans from ./demos/ftd_processed/makespans_train.npy with shape (8568, 7)
+2026-03-12 13:06:19.605 | INFO | __main__:train:321 - schedules.shape: torch.Size([64, 7])
+2026-03-12 13:06:19.605 | INFO | __main__:train:322 - makespans.shape: torch.Size([64, 7])
+2026-03-12 13:06:19.605 | INFO | __main__:train:388 - data_dir: ./demos/ftd_processed
+2026-03-12 13:06:19.605 | INFO | __main__:train:389 - block_size: 7
+2026-03-12 13:06:19.605 | INFO | __main__:train:390 - vocab_size: 7
+2026-03-12 13:06:19.605 | INFO | __main__:train:391 - n_embd: 64
+2026-03-12 13:06:19.605 | INFO | __main__:train:392 - n_head: 4
+2026-03-12 13:06:19.605 | INFO | __main__:train:393 - n_layer: 2
+2026-03-12 13:06:19.605 | INFO | __main__:train:394 - ff_width: 4
+2026-03-12 13:06:19.605 | INFO | __main__:train:395 - train_batch_size: 64
+2026-03-12 13:06:19.605 | INFO | __main__:train:396 - val_batch_size: 256
+2026-03-12 13:06:19.605 | INFO | __main__:train:397 - dropout: 0.0
+2026-03-12 13:06:19.606 | INFO | __main__:train:398 - nb_epochs: 5
+2026-03-12 13:06:19.606 | INFO | __main__:train:399 - early_stopping_patience: 15
+2026-03-12 13:06:19.606 | INFO | __main__:train:400 - nb_iters: 670
+2026-03-12 13:06:19.606 | INFO | __main__:train:401 - checkpoint_interval: 33
+2026-03-12 13:06:19.606 | INFO | __main__:train:402 - decay_lr: True
+2026-03-12 13:06:19.606 | INFO | __main__:train:403 - lr_partitions_ratios: [0.66, None]
+2026-03-12 13:06:19.606 | INFO | __main__:train:404 - lr_partitions_iters: [442, 228]
+2026-03-12 13:06:19.606 | INFO | __main__:train:405 - init_lr: 0.0001
+2026-03-12 13:06:19.606 | INFO | __main__:train:406 - max_lr: 0.001
+2026-03-12 13:06:19.606 | INFO | __main__:train:407 - min_lr: 5e-05
+2026-03-12 13:06:19.606 | INFO | __main__:train:408 - lr_warmup_iters_ratio: 0.1
+2026-03-12 13:06:19.606 | INFO | __main__:train:409 - lr_decay_iters_ratio: 0.95
+2026-03-12 13:06:19.606 | INFO | __main__:train:410 - beta1: 0.9
+2026-03-12 13:06:19.606 | INFO | __main__:train:411 - beta2: 0.95
+2026-03-12 13:06:19.606 | INFO | __main__:train:412 - weight_decay: 0.1
+2026-03-12 13:06:19.606 | INFO | __main__:train:413 - grad_clip: 1.0
+2026-03-12 13:06:19.606 | INFO | __main__:train:414 - compile: False
+2026-03-12 13:06:19.606 | INFO | __main__:train:415 - compile_mode: default
+2026-03-12 13:06:19.606 | INFO | __main__:train:416 - intermediate_schedules: True
+2026-03-12 13:06:19.606 | INFO | __main__:train:417 - save_only_last_checkpoint: True
+2026-03-12 13:06:19.703 | INFO | __main__:train:485 - The model has 100K trainable parameters
+2026-03-12 13:06:19.703 | INFO | __main__:train:506 - num decayed parameter tensors: 40, with 99,712 parameters
+2026-03-12 13:06:19.703 | INFO | __main__:train:507 - num non-decayed parameter tensors: 9, with 513 parameters
+2026-03-12 13:06:19.704 | INFO | __main__:train:511 - using fused AdamW: True
+2026-03-12 13:06:20.458 | INFO | __main__:__init__:288 - Loaded schedules from ./demos/ftd_processed/schedules_val.npy with shape (756, 7)
+2026-03-12 13:06:20.458 | INFO | __main__:__init__:293 - Loaded makespans from ./demos/ftd_processed/makespans_val.npy with shape (756, 7)
+2026-03-12 13:06:28.461 | INFO | __main__:train:698 - Best validation loss: 0.0003
source/demos/train_artifacts/train_pbar_epoch.log CHANGED
@@ -1 +1 @@
-Epoch 5/5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:01<00:00, 88.48it/s]
+Epoch 5/5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:01<00:00, 82.19it/s]
source/demos/train_artifacts/train_pbar_val.log CHANGED
@@ -1 +1 @@
-Validation 5.00: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 274.39it/s]
+Validation 5.00: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 232.26it/s]
source/launch_create_dataset.py ADDED
@@ -0,0 +1,23 @@
+from create_dataset import create_dataset
+
+params = [
+    {
+        "testing": False,
+        "nb_jobs": nb_jobs,
+        "nb_machines": nb_machines,
+        "time_min": 0,
+        "time_max": 1,
+        "nb_samples": 0,
+        "init_type": "exhaustive",
+        "output_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}",
+        "seed": 97,
+        "normalize_makespans": True,
+        "pfsp_instance": "",
+        "autoname_output_dir": False,
+    }
+    for nb_jobs in range(7, 11)
+    for nb_machines in range(2, 7)
+]
+
+for param in params:
+    create_dataset(**param)
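The comprehension builds the full cross product `range(7, 11)` × `range(2, 7)`, i.e. 20 dataset configurations run back to back. It expands to the nested loop below (a sketch; only the varying fields are spelled out):

```python
# Equivalent nested-loop form of the comprehension above: 4 x 5 = 20 configs.
params = []
for nb_jobs in range(7, 11):         # 7, 8, 9, 10
    for nb_machines in range(2, 7):  # 2, 3, 4, 5, 6
        params.append({
            "nb_jobs": nb_jobs,
            "nb_machines": nb_machines,
            "output_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}",
            # ...remaining fixed fields as in the file above
        })
assert len(params) == 20
```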
source/launch_process_dataset.py ADDED
@@ -0,0 +1,18 @@
+from process_dataset import process_dataset
+
+params = [
+    {
+        "input_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}",
+        "output_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{k_eliminated}",
+        "train_ratio": 0.85,
+        "seed": 97,
+        "eliminate_top_k_makespans": k_eliminated,
+        "duplication_factor": 0.0,
+    }
+    for nb_jobs in range(7, 11)
+    for nb_machines in range(2, 7)
+    for k_eliminated in [0, 1, 2, 3, 4]
+]
+
+for param in params:
+    process_dataset(**param)
source/launch_train.py ADDED
@@ -0,0 +1,42 @@
+from train import train
+
+params = [
+    {
+        "testing": False,
+        "seed": 97,
+        "data_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{top_k}",
+        "n_embd": 64,
+        "n_head": 4,
+        "n_layer": 2,
+        "ff_width": 4,
+        "intermediate_schedules": True,
+        "train_batch_size": 128,
+        "val_batch_size": 256,
+        "nb_epochs": 5,
+        "early_stopping_patience": 15,
+        "dropout": 0.0,
+        "checkpoint_interval_ratio": 1.0,
+        "decay_lr": True,
+        "lr_partitions_ratios": [0.66],
+        "init_lr": 1e-4,
+        "max_lr": 1e-3,
+        "min_lr": 5e-5,
+        "lr_warmup_iters_ratio": 0.1,
+        "lr_decay_iters_ratio": 0.95,
+        "beta1": 0.9,
+        "beta2": 0.95,
+        "weight_decay": 1e-1,
+        "grad_clip": 1.0,
+        "compile": "",
+        "compile_mode": "default",
+        "save_only_last_checkpoint": False,
+        "output_dir": f"../datasets/exhaustive_{nb_jobs}_{nb_machines}/top_{top_k}/train_Sm_Wd1e-1",
+    }
+    for nb_jobs in [7, 8, 9]
+    for nb_machines in [2, 3, 4, 5, 6]
+    for top_k in [0, 1, 2, 3, 4]
+
+]
+
+for param in params:
+    train(**param)
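This sweep amounts to 3 × 5 × 5 = 75 sequential `train()` calls; because `train()` resumes from `checkpoints/last_checkpoint.pth` when it exists and returns early once `.terminated_phase1` is written (see the train.py diff below), re-running the script after an interruption should skip finished runs. Note also that `"compile": ""` leans on the falsy empty string, the same convention as the CLI flags above. A quick size check (sketch):

```python
# Size of the sweep defined above: 3 * 5 * 5 = 75 train() calls.
n_runs = len([
    None
    for nb_jobs in [7, 8, 9]
    for nb_machines in [2, 3, 4, 5, 6]
    for top_k in [0, 1, 2, 3, 4]
])
assert n_runs == 75
```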
source/train.py CHANGED
@@ -173,52 +173,48 @@ class GPT(nn.Module):
 # ======
 
 
-if __name__ == "__main__":
-
-    # parse arguments
-    from argparse import ArgumentParser
-    parser = ArgumentParser()
-    parser.add_argument("--testing", type=bool, required=True)
-    parser.add_argument("--seed", type=int, required=True)
-    parser.add_argument("--data_dir", type=str, required=True)
-    parser.add_argument("--n_embd", type=int, required=True)
-    parser.add_argument("--n_head", type=int, required=True)
-    parser.add_argument("--n_layer", type=int, required=True)
-    parser.add_argument("--intermediate_schedules", type=bool, required=True)
-    parser.add_argument("--dropout", type=float, required=True)
-    parser.add_argument("--ff_width", type=int, required=True)
-    parser.add_argument("--train_batch_size", type=int, required=True)
-    parser.add_argument("--val_batch_size", type=int, required=True)
-    parser.add_argument("--nb_epochs", type=int, required=True)
-    parser.add_argument("--early_stopping_patience", type=int, required=True)
-    parser.add_argument("--checkpoint_interval_ratio", type=float, required=True)
-    parser.add_argument("--decay_lr", type=bool, required=True)
-    parser.add_argument("--lr_partitions_ratios", type=lambda s: [float(item) for item in s.split(',')], help='Comma-separated list of floats that do not add up to 1 (e.g., 0.1,0.5,1)', required=True)
-    parser.add_argument("--init_lr", type=float, required=True)
-    parser.add_argument("--max_lr", type=float, required=True)
-    parser.add_argument("--min_lr", type=float, required=True)
-    parser.add_argument("--lr_warmup_iters_ratio", type=float, required=True)
-    parser.add_argument("--lr_decay_iters_ratio", type=float, required=True)
-    parser.add_argument("--beta1", type=float, required=True)
-    parser.add_argument("--beta2", type=float, required=True)
-    parser.add_argument("--weight_decay", type=float, required=True)
-    parser.add_argument("--grad_clip", type=float, required=True)
-    parser.add_argument("--compile", type=bool, required=True)
-    parser.add_argument("--compile_mode", type=str, required=True)
-    parser.add_argument("--save_only_last_checkpoint", type=bool, required=True)
-    parser.add_argument("--output_dir", type=str, required=True)
-    args = parser.parse_args()
-
-    os.makedirs(args.output_dir, exist_ok=True)
+def train(
+    testing: bool,
+    seed: int,
+    data_dir: str,
+    n_embd: int,
+    n_head: int,
+    n_layer: int,
+    intermediate_schedules: bool,
+    dropout: float,
+    ff_width: int,
+    train_batch_size: int,
+    val_batch_size: int,
+    nb_epochs: int,
+    early_stopping_patience: int,
+    checkpoint_interval_ratio: float,
+    decay_lr: bool,
+    lr_partitions_ratios: list[float],
+    init_lr: float,
+    max_lr: float,
+    min_lr: float,
+    lr_warmup_iters_ratio: float,
+    lr_decay_iters_ratio: float,
+    beta1: float,
+    beta2: float,
+    weight_decay: float,
+    grad_clip: float,
+    compile: bool,
+    compile_mode: str,
+    save_only_last_checkpoint: bool,
+    output_dir: str,
+):
+
+    os.makedirs(output_dir, exist_ok=True)
 
     # check if experiment termination flag file exists
-    if not args.testing:
-        if os.path.exists(os.path.join(args.output_dir, ".terminated_phase1")):
+    if not testing:
+        if os.path.exists(os.path.join(output_dir, ".terminated_phase1")):
             print("Phase 1 already terminated. Exiting...")
-            exit()
+            return
     # ======
-    if not os.path.exists(os.path.join(args.output_dir, "viz_train.ipynb")):
-        shutil.copy("viz_train.ipynb", args.output_dir)
+    if not os.path.exists(os.path.join(output_dir, "viz_train.ipynb")):
+        shutil.copy("viz_train.ipynb", output_dir)
     # ======
     else:
 
@@ -236,12 +232,12 @@ if __name__ == "__main__":
             "viz_train.ipynb",
         ]
         for f in files_to_delete:
-            f_path = os.path.join(args.output_dir, f)
+            f_path = os.path.join(output_dir, f)
            if os.path.exists(f_path): os.remove(f_path)
        # ======
-        checkpoints_dir = os.path.join(args.output_dir, "checkpoints")
+        checkpoints_dir = os.path.join(output_dir, "checkpoints")
        if os.path.exists(checkpoints_dir): shutil.rmtree(checkpoints_dir)
-        shutil.copy("viz_train.ipynb", args.output_dir)
+        shutil.copy("viz_train.ipynb", output_dir)
        # ======
 
    # check if GPU is available
@@ -249,31 +245,31 @@
         device = "cuda"
 
     # setup logging
-    loguru.logger.add(os.path.join(args.output_dir, "train.log"))
+    loguru.logger.add(os.path.join(output_dir, "train.log"))
 
     # set random seeds
-    torch.manual_seed(args.seed)
-    random.seed(args.seed)
-    np.random.seed(args.seed)
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
 
     # setup model architecture parameters
-    with open(os.path.join(args.data_dir, "metadata.json"), "r") as f:
+    with open(os.path.join(data_dir, "metadata.json"), "r") as f:
         metadata = json.load(f)
     block_size = metadata["nb_jobs"] # context window size
     vocab_size = metadata["nb_jobs"] # vocabulary size
-    n_embd = args.n_embd # embedding dimension
-    n_head = args.n_head # number of attention heads
+    n_embd = n_embd # embedding dimension
+    n_head = n_head # number of attention heads
     assert n_embd % n_head == 0
-    n_layer = args.n_layer # number of transformer blocks
-    intermediate_schedules = args.intermediate_schedules
-    ff_width = args.ff_width
+    n_layer = n_layer # number of transformer blocks
+    intermediate_schedules = intermediate_schedules
+    ff_width = ff_width
 
     # setup training parameters and utils
-    train_batch_size = args.train_batch_size # batch size for training
-    val_batch_size = args.val_batch_size # batch size for validation
-    nb_epochs = args.nb_epochs # number of pseudo-epochs to train for
-    early_stopping_patience = args.early_stopping_patience # number of epochs without improvement to trigger early stopping
-    dropout = args.dropout
+    train_batch_size = train_batch_size # batch size for training
+    val_batch_size = val_batch_size # batch size for validation
+    nb_epochs = nb_epochs # number of pseudo-epochs to train for
+    early_stopping_patience = early_stopping_patience # number of epochs without improvement to trigger early stopping
+    dropout = dropout
 
 
     class FlowshopDataset(torch.utils.data.Dataset):
@@ -313,7 +309,7 @@
     # ======
 
 
-    train_dataset = FlowshopDataset(args.data_dir, split="train", load_in_memory=False)
+    train_dataset = FlowshopDataset(data_dir, split="train", load_in_memory=False)
     train_data_loader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=train_batch_size,
@@ -326,24 +322,24 @@
         loguru.logger.info(f"makespans.shape: {makespans.shape}")
         break
     nb_iters = nb_epochs * len(train_data_loader)
-    checkpoint_interval = int(args.checkpoint_interval_ratio * len(train_data_loader))
-    decay_lr = args.decay_lr
-    lr_partitions_ratios = args.lr_partitions_ratios + [None]
+    checkpoint_interval = int(checkpoint_interval_ratio * len(train_data_loader))
+    decay_lr = decay_lr
+    lr_partitions_ratios = lr_partitions_ratios + [None]
     lr_partitions_iters = [int(r * nb_iters) for r in lr_partitions_ratios[:-1]]
     lr_partitions_iters = lr_partitions_iters + [nb_iters - sum(lr_partitions_iters)]
     assert sum(lr_partitions_iters) == nb_iters
-    init_lr = args.init_lr #1e-4
-    max_lr = args.max_lr #1e-3
-    min_lr = args.min_lr #5*1e-5
-    lr_warmup_iters_ratio = args.lr_warmup_iters_ratio #0.1
-    lr_decay_iters_ratio = args.lr_decay_iters_ratio #0.95
-    beta1 = args.beta1 # Adam beta1
-    beta2 = args.beta2 # Adam beta2
-    weight_decay = args.weight_decay # 1e-1 # weight decay
-    grad_clip = args.grad_clip # 1.0 # gradient clipping value
-    compile = args.compile
-    compile_mode = args.compile_mode
-    save_only_last_checkpoint = args.save_only_last_checkpoint
+    init_lr = init_lr #1e-4
+    max_lr = max_lr #1e-3
+    min_lr = min_lr #5*1e-5
+    lr_warmup_iters_ratio = lr_warmup_iters_ratio #0.1
+    lr_decay_iters_ratio = lr_decay_iters_ratio #0.95
+    beta1 = beta1 # Adam beta1
+    beta2 = beta2 # Adam beta2
+    weight_decay = weight_decay # 1e-1 # weight decay
+    grad_clip = grad_clip # 1.0 # gradient clipping value
+    compile = compile
+    compile_mode = compile_mode
+    save_only_last_checkpoint = save_only_last_checkpoint
 
 
     def human_readable(num):
@@ -389,7 +385,7 @@
 
 
     # log parameters
-    loguru.logger.info(f"data_dir: {args.data_dir}")
+    loguru.logger.info(f"data_dir: {data_dir}")
     loguru.logger.info(f"block_size: {block_size}")
     loguru.logger.info(f"vocab_size: {vocab_size}")
     loguru.logger.info(f"n_embd: {n_embd}")
@@ -421,9 +417,8 @@
     loguru.logger.info(f"save_only_last_checkpoint: {save_only_last_checkpoint}")
 
     # save parameters into a train_parameters.json
-    import json
     train_params = {
-        "data_dir": args.data_dir,
+        "data_dir": data_dir,
         "block_size": block_size,
         "vocab_size": vocab_size,
         "n_embd": n_embd,
@@ -454,11 +449,11 @@
         "intermediate_schedules": intermediate_schedules,
         "save_only_last_checkpoint": save_only_last_checkpoint,
     }
-    with open(os.path.join(args.output_dir, "train_parameters.json"), "w") as f: json.dump(train_params, f, indent=4)
+    with open(os.path.join(output_dir, "train_parameters.json"), "w") as f: json.dump(train_params, f, indent=4)
 
     # load the last checkpoint if it exists, otherwise initialize the training from scratch
     try:
-        last_checkpoint = torch.load(os.path.join(args.output_dir, "checkpoints", "last_checkpoint.pth"))
+        last_checkpoint = torch.load(os.path.join(output_dir, "checkpoints", "last_checkpoint.pth"))
         start_epoch = last_checkpoint["epoch"]
         start_epoch_iter = last_checkpoint["epoch_iter"] + 1
         model_state_dict = last_checkpoint["model_state_dict"]
@@ -467,7 +462,7 @@
         patience_counter = last_checkpoint["patience_counter"]
         improved_this_epoch = last_checkpoint["improved_this_epoch"]
     except FileNotFoundError:
-        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
+        os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True)
         start_epoch = 0
         start_epoch_iter = 0
         model_state_dict = None
@@ -524,10 +519,10 @@
     torch.set_float32_matmul_precision("high")
 
     # initialize the np memmap array to save the batch losses
-    batch_losses_path = os.path.join(args.output_dir, "batch_losses.npy")
-    last_batch_loss_idx_path = os.path.join(args.output_dir, "last_batch_loss_idx.npy")
-    val_losses_path = os.path.join(args.output_dir, "val_losses.npy")
-    last_val_loss_idx_path = os.path.join(args.output_dir, "last_val_loss_idx.npy")
+    batch_losses_path = os.path.join(output_dir, "batch_losses.npy")
+    last_batch_loss_idx_path = os.path.join(output_dir, "last_batch_loss_idx.npy")
+    val_losses_path = os.path.join(output_dir, "val_losses.npy")
+    last_val_loss_idx_path = os.path.join(output_dir, "last_val_loss_idx.npy")
 
     try:
         batch_losses = np.lib.format.open_memmap(batch_losses_path, mode="r+", dtype=np.float32, shape=(nb_iters,))
@@ -546,7 +541,7 @@
 
     # create data_loader for validation
     val_data_loader = torch.utils.data.DataLoader(
-        FlowshopDataset(args.data_dir, split="val", load_in_memory=True),
+        FlowshopDataset(data_dir, split="val", load_in_memory=True),
         batch_size=val_batch_size,
         shuffle=False,
     )
@@ -560,7 +555,7 @@
         # implement the logic to resume after failure
         ## create the generator, sampler, data loader
         generator = torch.Generator()
-        generator.manual_seed(args.seed + epoch)
+        generator.manual_seed(seed + epoch)
         train_sampler = torch.utils.data.RandomSampler(
             train_dataset,
             generator=generator
@@ -586,7 +581,7 @@
             initial=start_epoch_iter,
             desc=f"Epoch {epoch+1}/{nb_epochs}",
         )):
-            with open(os.path.join(args.output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
+            with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
 
             # move the batch to the device
             schedules_batch = schedules_batch.to(device)
@@ -629,7 +624,7 @@
                     val_data_loader,
                     desc=f"Validation {epoch+(epoch_iter+1)/len(train_data_loader):.2f}",
                 )):
-                    with open(os.path.join(args.output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
+                    with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
 
                     # move the batch to the device
                     schedules_batch = schedules_batch.to(device)
@@ -640,7 +635,7 @@
                     makespans, loss = train_model(schedules_batch, makespans_batch)
                     total_val_loss += loss.item() * schedules_batch.size(0)
                 # ======
-                with open(os.path.join(args.output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
+                with open(os.path.join(output_dir, "train_pbar_val.log"), "w") as f: f.write(str(pbar2))
 
                 # compute the total validation loss (averaging over the dataset)
                 total_val_loss /= len(val_data_loader.dataset)
@@ -669,22 +664,22 @@
                     }
                     torch.save(
                         checkpoint,
-                        os.path.join(args.output_dir, "checkpoints", "last_checkpoint.pth")
+                        os.path.join(output_dir, "checkpoints", "last_checkpoint.pth")
                     )
                    if not save_only_last_checkpoint:
                        torch.save(
                            checkpoint,
-                            os.path.join(args.output_dir, "checkpoints", f"checkpoint_epoch_{epoch+(epoch_iter+1)/len(train_data_loader):.2f}.pth")
+                            os.path.join(output_dir, "checkpoints", f"checkpoint_epoch_{epoch+(epoch_iter+1)/len(train_data_loader):.2f}.pth")
                        )
                    if best_val_loss == total_val_loss:
                        torch.save(
                            checkpoint,
-                            os.path.join(args.output_dir, "checkpoints", "best_checkpoint.pth")
+                            os.path.join(output_dir, "checkpoints", "best_checkpoint.pth")
                        )
                # ======
            # ======
        # ======
-        with open(os.path.join(args.output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
+        with open(os.path.join(output_dir, "train_pbar_epoch.log"), "w") as f: f.write(str(pbar))
 
        # set the start_epoch_iter to 0 for the next epoch
        start_epoch_iter = 0
@@ -703,7 +698,77 @@
     loguru.logger.info(f"Best validation loss: {best_val_loss:.4f}")
 
     # create experiment termination flag file
-    with open(os.path.join(args.output_dir, ".terminated_phase1"), "w") as f:
+    with open(os.path.join(output_dir, ".terminated_phase1"), "w") as f:
         pass
     # ======
-    # ======
+    # ======
+
+
+if __name__ == "__main__":
+
+    # parse arguments
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--testing", type=bool, required=True)
+    parser.add_argument("--seed", type=int, required=True)
+    parser.add_argument("--data_dir", type=str, required=True)
+    parser.add_argument("--n_embd", type=int, required=True)
+    parser.add_argument("--n_head", type=int, required=True)
+    parser.add_argument("--n_layer", type=int, required=True)
+    parser.add_argument("--intermediate_schedules", type=bool, required=True)
+    parser.add_argument("--dropout", type=float, required=True)
+    parser.add_argument("--ff_width", type=int, required=True)
+    parser.add_argument("--train_batch_size", type=int, required=True)
+    parser.add_argument("--val_batch_size", type=int, required=True)
+    parser.add_argument("--nb_epochs", type=int, required=True)
+    parser.add_argument("--early_stopping_patience", type=int, required=True)
+    parser.add_argument("--checkpoint_interval_ratio", type=float, required=True)
+    parser.add_argument("--decay_lr", type=bool, required=True)
+    parser.add_argument("--lr_partitions_ratios", type=lambda s: [float(item) for item in s.split(',')], help='Comma-separated list of floats that do not add up to 1 (e.g., 0.1,0.5,1)', required=True)
+    parser.add_argument("--init_lr", type=float, required=True)
+    parser.add_argument("--max_lr", type=float, required=True)
+    parser.add_argument("--min_lr", type=float, required=True)
+    parser.add_argument("--lr_warmup_iters_ratio", type=float, required=True)
+    parser.add_argument("--lr_decay_iters_ratio", type=float, required=True)
+    parser.add_argument("--beta1", type=float, required=True)
+    parser.add_argument("--beta2", type=float, required=True)
+    parser.add_argument("--weight_decay", type=float, required=True)
+    parser.add_argument("--grad_clip", type=float, required=True)
+    parser.add_argument("--compile", type=bool, required=True)
+    parser.add_argument("--compile_mode", type=str, required=True)
+    parser.add_argument("--save_only_last_checkpoint", type=bool, required=True)
+    parser.add_argument("--output_dir", type=str, required=True)
+    args = parser.parse_args()
+
+
+    train(
+        testing=args.testing,
+        seed=args.seed,
+        data_dir=args.data_dir,
+        n_embd=args.n_embd,
+        n_head=args.n_head,
+        n_layer=args.n_layer,
+        intermediate_schedules=args.intermediate_schedules,
+        dropout=args.dropout,
+        ff_width=args.ff_width,
+        train_batch_size=args.train_batch_size,
+        val_batch_size=args.val_batch_size,
+        nb_epochs=args.nb_epochs,
+        early_stopping_patience=args.early_stopping_patience,
+        checkpoint_interval_ratio=args.checkpoint_interval_ratio,
+        decay_lr=args.decay_lr,
+        lr_partitions_ratios=args.lr_partitions_ratios,
+        init_lr=args.init_lr,
+        max_lr=args.max_lr,
+        min_lr=args.min_lr,
+        lr_warmup_iters_ratio=args.lr_warmup_iters_ratio,
+        lr_decay_iters_ratio=args.lr_decay_iters_ratio,
+        beta1=args.beta1,
+        beta2=args.beta2,
+        weight_decay=args.weight_decay,
+        grad_clip=args.grad_clip,
+        compile=args.compile,
+        compile_mode=args.compile_mode,
+        save_only_last_checkpoint=args.save_only_last_checkpoint,
+        output_dir=args.output_dir,
+    )