Upload 39 files
Browse files- .gitattributes +3 -0
- data/benchmarks/taillard/ft06.json +1 -0
- data/benchmarks/taillard/ft10.json +1 -0
- data/benchmarks/taillard/ta01.json +1 -0
- data/benchmarks/taillard/ta02.json +1 -0
- data/benchmarks/taillard/ta03.json +1 -0
- data/raw/priority_dataset.csv +0 -0
- data/raw/priority_dataset_augmented.csv +3 -0
- data/raw/selector_dataset.csv +0 -0
- data/real/calibrated_params.json +20 -0
- data/real/olist_order_items_dataset.csv +3 -0
- data/real/olist_orders_dataset.csv +3 -0
- data/real/olist_products_dataset.csv +0 -0
- scripts/calibrate_real_data.py +770 -0
- scripts/foolproof_retrain.py +476 -0
- scripts/hf_runner.py +62 -0
- scripts/run_pipeline.py +139 -0
- scripts/run_preset_benchmark.py +220 -0
- src/__init__.py +84 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/data_generator.cpython-312.pyc +0 -0
- src/__pycache__/evaluator.cpython-312.pyc +0 -0
- src/__pycache__/features.cpython-312.pyc +0 -0
- src/__pycache__/heuristics.cpython-312.pyc +0 -0
- src/__pycache__/hybrid_scheduler.cpython-312.pyc +0 -0
- src/__pycache__/presets.cpython-312.pyc +0 -0
- src/__pycache__/references.cpython-312.pyc +0 -0
- src/__pycache__/simulator.cpython-312.pyc +0 -0
- src/__pycache__/train_priority.cpython-312.pyc +0 -0
- src/__pycache__/train_selector.cpython-312.pyc +0 -0
- src/data_generator.py +425 -0
- src/evaluator.py +899 -0
- src/features.py +508 -0
- src/heuristics.py +197 -0
- src/hybrid_scheduler.py +865 -0
- src/presets.py +399 -0
- src/references.py +179 -0
- src/simulator.py +1302 -0
- src/train_priority.py +139 -0
- src/train_selector.py +316 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/raw/priority_dataset_augmented.csv filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/real/olist_order_items_dataset.csv filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/real/olist_orders_dataset.csv filter=lfs diff=lfs merge=lfs -text
|
data/benchmarks/taillard/ft06.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"name": "ft06", "n_jobs": 6, "n_machines": 6, "processing_times": [[47, 51, 75, 95, 4, 15], [82, 94, 25, 31, 87, 42], [28, 82, 26, 41, 64, 55], [9, 3, 86, 75, 83, 54], [81, 33, 45, 79, 13, 31], [13, 45, 97, 14, 38, 40]], "machine_order": [[3, 1, 2, 0, 4, 5], [5, 3, 2, 4, 0, 1], [0, 5, 3, 1, 4, 2], [0, 5, 4, 1, 2, 3], [3, 2, 1, 4, 5, 0], [0, 1, 4, 5, 3, 2]]}
|
data/benchmarks/taillard/ft10.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"name": "ft10", "n_jobs": 10, "n_machines": 10, "processing_times": [[83, 26, 11, 30, 41, 81, 45, 10, 34, 60], [81, 73, 99, 19, 88, 6, 56, 28, 20, 66], [31, 56, 26, 15, 75, 43, 68, 67, 94, 42], [22, 63, 93, 96, 86, 68, 38, 39, 4, 19], [33, 35, 58, 51, 69, 89, 87, 77, 97, 32], [90, 92, 23, 47, 57, 69, 70, 11, 48, 11], [95, 20, 45, 88, 52, 68, 50, 85, 59, 64], [45, 41, 59, 52, 78, 59, 47, 86, 22, 44], [49, 89, 44, 61, 8, 83, 44, 50, 10, 69], [63, 34, 77, 52, 6, 22, 26, 10, 50, 4]], "machine_order": [[1, 7, 6, 4, 9, 0, 3, 8, 2, 5], [2, 3, 0, 6, 1, 8, 7, 9, 5, 4], [9, 7, 8, 5, 0, 4, 3, 6, 1, 2], [2, 0, 5, 8, 7, 4, 3, 1, 6, 9], [2, 0, 6, 5, 3, 8, 7, 4, 9, 1], [2, 0, 3, 9, 5, 8, 1, 7, 6, 4], [3, 2, 5, 7, 8, 4, 0, 9, 6, 1], [2, 1, 6, 7, 8, 9, 4, 5, 0, 3], [1, 6, 0, 4, 5, 7, 2, 3, 8, 9], [4, 6, 8, 9, 5, 0, 3, 7, 1, 2]]}
|
data/benchmarks/taillard/ta01.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"name": "ta01", "n_jobs": 15, "n_machines": 15, "processing_times": [[9, 77, 65, 44, 43, 86, 9, 70, 20, 10, 53, 97, 73, 76, 72], [78, 51, 13, 84, 45, 50, 37, 19, 92, 78, 64, 40, 82, 54, 44], [45, 23, 10, 55, 88, 7, 85, 82, 28, 63, 17, 76, 70, 36, 7], [97, 45, 89, 68, 78, 76, 20, 37, 47, 50, 5, 55, 16, 74, 68], [92, 74, 37, 96, 41, 33, 90, 37, 8, 47, 79, 19, 46, 13, 68], [48, 33, 23, 56, 67, 94, 44, 16, 83, 63, 70, 10, 31, 77, 83], [44, 80, 84, 39, 89, 29, 24, 68, 64, 14, 83, 20, 80, 1, 79], [78, 78, 66, 47, 70, 28, 78, 56, 46, 51, 57, 4, 14, 25, 12], [44, 67, 65, 47, 85, 56, 8, 76, 57, 63, 56, 55, 9, 56, 79], [31, 60, 4, 35, 44, 98, 22, 28, 41, 99, 85, 4, 24, 82, 6], [85, 28, 91, 30, 43, 66, 13, 56, 50, 78, 99, 66, 41, 41, 42], [81, 32, 17, 34, 3, 11, 9, 77, 72, 69, 46, 71, 16, 90, 50], [93, 16, 50, 69, 50, 45, 17, 38, 24, 30, 68, 63, 61, 36, 96], [9, 34, 12, 34, 96, 37, 90, 50, 70, 46, 27, 76, 96, 27, 78], [26, 71, 79, 45, 73, 27, 8, 10, 45, 90, 13, 46, 70, 21, 72]], "machine_order": [[0, 6, 8, 3, 13, 10, 11, 2, 5, 7, 9, 12, 1, 4, 14], [6, 10, 7, 5, 3, 4, 1, 12, 8, 0, 2, 9, 11, 14, 13], [7, 11, 3, 12, 10, 4, 0, 9, 8, 1, 6, 5, 14, 2, 13], [10, 3, 8, 6, 9, 5, 2, 13, 12, 7, 1, 0, 11, 14, 4], [12, 3, 9, 8, 4, 13, 6, 10, 1, 5, 0, 2, 14, 11, 7], [0, 8, 7, 2, 4, 13, 11, 3, 6, 10, 1, 12, 9, 5, 14], [12, 2, 6, 4, 0, 10, 7, 1, 9, 14, 11, 3, 5, 13, 8], [3, 0, 12, 11, 6, 4, 13, 10, 5, 9, 14, 2, 7, 8, 1], [3, 0, 9, 13, 8, 14, 12, 2, 7, 11, 5, 4, 10, 6, 1], [14, 6, 8, 12, 13, 5, 9, 11, 7, 1, 4, 3, 2, 10, 0], [4, 7, 12, 3, 14, 8, 6, 0, 1, 10, 13, 2, 5, 11, 9], [9, 8, 2, 11, 12, 6, 10, 7, 5, 3, 13, 0, 14, 4, 1], [13, 11, 1, 7, 0, 14, 2, 3, 9, 4, 6, 8, 10, 12, 5], [6, 4, 0, 1, 13, 7, 8, 12, 5, 11, 2, 10, 9, 3, 14], [13, 5, 9, 0, 4, 8, 3, 11, 12, 1, 2, 10, 6, 14, 7]]}
|
data/benchmarks/taillard/ta02.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"name": "ta02", "n_jobs": 15, "n_machines": 15, "processing_times": [[2, 68, 59, 6, 90, 22, 26, 19, 34, 18, 35, 81, 45, 92, 45], [28, 79, 82, 86, 89, 3, 51, 27, 25, 24, 82, 79, 22, 41, 74], [15, 63, 44, 92, 74, 23, 83, 80, 22, 52, 79, 23, 25, 17, 1], [50, 2, 58, 42, 19, 36, 2, 15, 47, 8, 73, 26, 91, 17, 62], [52, 91, 62, 86, 46, 22, 19, 86, 21, 73, 53, 28, 41, 79, 45], [86, 94, 30, 53, 53, 54, 8, 37, 58, 91, 24, 64, 76, 92, 18], [67, 31, 44, 2, 72, 4, 38, 50, 67, 47, 84, 13, 26, 26, 38], [1, 43, 38, 58, 58, 28, 43, 50, 83, 8, 62, 8, 27, 44, 81], [34, 50, 58, 76, 40, 57, 84, 44, 93, 40, 43, 3, 53, 47, 42], [62, 55, 94, 41, 44, 81, 49, 25, 52, 51, 41, 53, 58, 97, 7], [91, 49, 38, 61, 12, 74, 10, 43, 20, 30, 47, 1, 88, 75, 95], [8, 18, 49, 85, 31, 30, 84, 1, 95, 19, 32, 33, 89, 82, 34], [34, 81, 29, 80, 81, 65, 74, 23, 56, 14, 6, 43, 30, 16, 53], [87, 25, 18, 86, 3, 59, 56, 47, 43, 3, 86, 7, 16, 88, 36], [46, 64, 56, 93, 93, 26, 76, 69, 25, 15, 81, 74, 38, 30, 69]], "machine_order": [[0, 2, 14, 12, 13, 10, 3, 5, 6, 11, 9, 8, 7, 1, 4], [13, 5, 8, 14, 6, 4, 0, 10, 12, 7, 11, 3, 1, 9, 2], [10, 4, 7, 3, 12, 9, 8, 14, 11, 2, 6, 5, 0, 1, 13], [7, 3, 5, 14, 10, 12, 13, 1, 9, 6, 11, 2, 4, 0, 8], [8, 4, 1, 5, 0, 2, 3, 13, 11, 9, 12, 14, 10, 7, 6], [6, 12, 1, 11, 2, 9, 3, 5, 7, 13, 8, 4, 10, 14, 0], [6, 2, 3, 12, 7, 5, 1, 8, 14, 10, 9, 4, 13, 11, 0], [6, 0, 1, 8, 4, 2, 5, 11, 3, 12, 14, 13, 7, 10, 9], [3, 9, 12, 5, 1, 14, 11, 4, 2, 7, 0, 10, 6, 13, 8], [7, 0, 5, 14, 9, 10, 13, 3, 4, 11, 2, 1, 12, 8, 6], [0, 12, 1, 3, 2, 5, 10, 13, 8, 9, 11, 6, 14, 7, 4], [4, 12, 14, 11, 10, 0, 5, 7, 6, 8, 2, 13, 9, 1, 3], [4, 7, 12, 1, 8, 10, 0, 9, 3, 6, 13, 5, 14, 2, 11], [5, 13, 10, 0, 11, 14, 7, 12, 9, 4, 3, 6, 2, 8, 1], [8, 9, 0, 1, 6, 2, 4, 14, 3, 7, 13, 11, 5, 12, 10]]}
|
data/benchmarks/taillard/ta03.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"name": "ta03", "n_jobs": 15, "n_machines": 15, "processing_times": [[35, 47, 47, 45, 5, 13, 82, 75, 62, 97, 7, 9, 33, 31, 68], [79, 69, 40, 52, 98, 26, 24, 50, 56, 49, 53, 44, 32, 47, 24], [24, 55, 85, 34, 69, 80, 17, 37, 10, 18, 98, 77, 6, 68, 56], [29, 22, 1, 18, 92, 93, 63, 77, 51, 98, 51, 11, 38, 53, 49], [85, 64, 85, 96, 5, 36, 71, 76, 10, 74, 27, 38, 63, 71, 65], [3, 97, 95, 30, 76, 94, 26, 11, 45, 93, 65, 86, 28, 37, 40], [67, 74, 71, 66, 27, 52, 62, 27, 23, 15, 92, 63, 64, 88, 25], [28, 38, 13, 77, 22, 25, 70, 23, 69, 24, 26, 3, 13, 27, 94], [19, 12, 88, 95, 61, 89, 31, 93, 82, 49, 40, 6, 92, 15, 94], [80, 64, 33, 8, 78, 70, 40, 56, 26, 10, 74, 89, 71, 26, 88], [36, 67, 7, 40, 46, 11, 15, 6, 60, 24, 98, 58, 75, 88, 71], [65, 90, 54, 56, 20, 72, 98, 37, 44, 30, 41, 84, 19, 53, 89], [41, 66, 94, 12, 19, 24, 52, 12, 97, 16, 19, 20, 78, 38, 14], [22, 26, 91, 50, 76, 2, 93, 81, 20, 74, 81, 24, 42, 37, 93], [22, 27, 59, 61, 73, 18, 60, 3, 45, 52, 17, 11, 19, 39, 34]], "machine_order": [[14, 11, 4, 2, 7, 10, 5, 8, 6, 9, 13, 0, 12, 1, 3], [5, 7, 0, 12, 1, 10, 9, 2, 4, 3, 6, 13, 14, 8, 11], [7, 0, 4, 13, 3, 2, 8, 1, 6, 12, 5, 14, 9, 10, 11], [3, 8, 13, 2, 11, 14, 7, 1, 10, 12, 4, 5, 9, 6, 0], [0, 6, 7, 3, 10, 5, 1, 9, 8, 11, 2, 12, 4, 13, 14], [8, 12, 2, 11, 5, 14, 13, 9, 3, 7, 1, 10, 0, 6, 4], [8, 2, 13, 5, 7, 4, 3, 12, 14, 11, 0, 10, 1, 6, 9], [3, 2, 13, 5, 8, 12, 6, 1, 9, 7, 11, 14, 10, 0, 4], [0, 8, 14, 2, 7, 1, 11, 13, 12, 3, 5, 10, 9, 4, 6], [3, 6, 8, 0, 9, 11, 4, 12, 1, 5, 2, 10, 13, 7, 14], [1, 7, 12, 5, 0, 2, 4, 3, 9, 14, 6, 10, 13, 8, 11], [6, 14, 7, 4, 13, 12, 1, 9, 0, 3, 11, 5, 8, 2, 10], [7, 11, 6, 3, 5, 9, 8, 13, 2, 10, 12, 4, 1, 0, 14], [5, 13, 7, 6, 1, 12, 10, 11, 9, 0, 14, 3, 2, 8, 4], [9, 2, 13, 6, 5, 8, 1, 11, 3, 10, 0, 12, 7, 14, 4]]}
|
data/raw/priority_dataset.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/raw/priority_dataset_augmented.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:61f28925ede345853a95d07285fe4076563d38f734ae9a552217c89234400b83
|
| 3 |
+
size 29492802
|
data/raw/selector_dataset.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/real/calibrated_params.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"source": "calibrated_from_olist_real_data",
|
| 3 |
+
"arrival_rate_per_min": 0.5,
|
| 4 |
+
"due_date_tightness": 1.5,
|
| 5 |
+
"job_type_frequencies": {
|
| 6 |
+
"A": 0.21,
|
| 7 |
+
"B": 0.28,
|
| 8 |
+
"C": 0.223,
|
| 9 |
+
"D": 0.187,
|
| 10 |
+
"E": 0.1
|
| 11 |
+
},
|
| 12 |
+
"sla_breach_rate_baseline_target": 0.08112366538820359,
|
| 13 |
+
"raw_olist_stats": {
|
| 14 |
+
"orders_per_day_mean": 157.6437908496732,
|
| 15 |
+
"orders_per_600min_shift": 98.52736928104575,
|
| 16 |
+
"sla_window_median_days": 23.23087962962963,
|
| 17 |
+
"cycle_time_median_days": 10.217476851851853,
|
| 18 |
+
"sla_breach_rate": 0.08112366538820359
|
| 19 |
+
}
|
| 20 |
+
}
|
data/real/olist_order_items_dataset.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f6abdbbc94036d0df4a76fa0520c072e31a40119d70f7f370fba1e2285d2bcb
|
| 3 |
+
size 15007623
|
data/real/olist_orders_dataset.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8df58ef3d2d7e9944010f7beecd9b75367f5588ec6e3c91cec19ae3345ef9ecf
|
| 3 |
+
size 17654914
|
data/real/olist_products_dataset.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/calibrate_real_data.py
ADDED
|
@@ -0,0 +1,770 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
scripts/calibrate_real_data.py — Real-Data Calibration for DAHS_2
|
| 4 |
+
|
| 5 |
+
Uses three real datasets to ground simulator parameters:
|
| 6 |
+
1. Olist Brazilian E-Commerce (99,441 orders) — arrival rates, SLA windows, tardiness
|
| 7 |
+
2. E-Commerce Shipping (Prachi13 structure, synthetic-real hybrid) — zone/breach structure
|
| 8 |
+
3. Taillard JSP benchmarks — heuristic validation vs published bounds
|
| 9 |
+
|
| 10 |
+
Outputs:
|
| 11 |
+
- results/calibration/arrival_rate_analysis.png
|
| 12 |
+
- results/calibration/sla_window_analysis.png
|
| 13 |
+
- results/calibration/tardiness_distribution.png
|
| 14 |
+
- results/calibration/taillard_heuristic_comparison.png
|
| 15 |
+
- results/calibration/calibration_report.json
|
| 16 |
+
- data/real/calibrated_params.json (updated simulator params)
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python scripts/calibrate_real_data.py
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import matplotlib
|
| 29 |
+
matplotlib.use("Agg")
|
| 30 |
+
import matplotlib.pyplot as plt
|
| 31 |
+
import numpy as np
|
| 32 |
+
import pandas as pd
|
| 33 |
+
from scipy import stats
|
| 34 |
+
|
| 35 |
+
ROOT = Path(__file__).parent.parent
|
| 36 |
+
sys.path.insert(0, str(ROOT))
|
| 37 |
+
|
| 38 |
+
# Force UTF-8 output
|
| 39 |
+
for _s in ("stdout", "stderr"):
|
| 40 |
+
try:
|
| 41 |
+
getattr(sys, _s).reconfigure(encoding="utf-8", errors="replace")
|
| 42 |
+
except Exception:
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
REAL_DIR = ROOT / "data" / "real"
|
| 49 |
+
BENCH_DIR = ROOT / "data" / "benchmarks" / "taillard"
|
| 50 |
+
RESULTS_DIR = ROOT / "results" / "calibration"
|
| 51 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# =============================================================================
|
| 55 |
+
# PART 1: Olist Arrival Rate Analysis
|
| 56 |
+
# =============================================================================
|
| 57 |
+
|
| 58 |
+
def analyze_olist_arrivals(orders_path: Path) -> dict:
|
| 59 |
+
"""Extract hourly arrival rates from Olist timestamps."""
|
| 60 |
+
logger.info("Loading Olist orders: %s", orders_path)
|
| 61 |
+
df = pd.read_csv(orders_path, parse_dates=["order_purchase_timestamp"])
|
| 62 |
+
|
| 63 |
+
# Filter to delivered orders only (clean data)
|
| 64 |
+
df = df[df["order_status"] == "delivered"].copy()
|
| 65 |
+
logger.info("Delivered orders: %d", len(df))
|
| 66 |
+
|
| 67 |
+
# Hourly arrival counts
|
| 68 |
+
df["hour"] = df["order_purchase_timestamp"].dt.hour
|
| 69 |
+
df["date"] = df["order_purchase_timestamp"].dt.date
|
| 70 |
+
df["weekday"] = df["order_purchase_timestamp"].dt.weekday
|
| 71 |
+
|
| 72 |
+
# Orders per day
|
| 73 |
+
daily_counts = df.groupby("date").size()
|
| 74 |
+
orders_per_day_mean = float(daily_counts.mean())
|
| 75 |
+
orders_per_day_std = float(daily_counts.std())
|
| 76 |
+
orders_per_hour_mean = orders_per_day_mean / 16 # 16-hour operating window
|
| 77 |
+
|
| 78 |
+
logger.info("Mean orders/day: %.1f, std: %.1f", orders_per_day_mean, orders_per_day_std)
|
| 79 |
+
logger.info("Implied mean orders/hour: %.1f", orders_per_hour_mean)
|
| 80 |
+
|
| 81 |
+
# Hourly distribution (fraction of daily orders per hour)
|
| 82 |
+
hourly_dist = df.groupby("hour").size() / len(df)
|
| 83 |
+
|
| 84 |
+
# Peak hour analysis (warehouse typically operates 6am-10pm)
|
| 85 |
+
op_hours = df[(df["hour"] >= 6) & (df["hour"] <= 22)]
|
| 86 |
+
op_hourly = op_hours.groupby("hour").size()
|
| 87 |
+
op_hourly_norm = op_hourly / op_hourly.sum()
|
| 88 |
+
|
| 89 |
+
# Fit Poisson rate (orders/min during operating hours)
|
| 90 |
+
daily_op = df.groupby("date").size()
|
| 91 |
+
# Scale to 600-min shift: 600min / (60*16) * daily_mean
|
| 92 |
+
orders_per_600min = orders_per_day_mean * (600 / (60 * 16))
|
| 93 |
+
arrival_rate_per_min = orders_per_600min / 600
|
| 94 |
+
|
| 95 |
+
# Day-of-week effect
|
| 96 |
+
dow_counts = df.groupby("weekday").size()
|
| 97 |
+
peak_day = int(dow_counts.idxmax())
|
| 98 |
+
dow_factor = float(dow_counts.max() / dow_counts.mean())
|
| 99 |
+
|
| 100 |
+
logger.info("Estimated arrival_rate_per_min: %.4f", arrival_rate_per_min)
|
| 101 |
+
|
| 102 |
+
# ---- Plot ----
|
| 103 |
+
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
|
| 104 |
+
fig.patch.set_facecolor("#0f1117")
|
| 105 |
+
fig.suptitle("Olist E-Commerce: Real Order Arrival Patterns", color="white", fontsize=14, y=1.01)
|
| 106 |
+
|
| 107 |
+
# 1. Daily volume distribution
|
| 108 |
+
ax = axes[0]
|
| 109 |
+
ax.set_facecolor("#1a1d27")
|
| 110 |
+
ax.hist(daily_counts.values, bins=40, color="#4fc3f7", alpha=0.85, edgecolor="none")
|
| 111 |
+
ax.axvline(orders_per_day_mean, color="#ff7043", lw=2, linestyle="--", label=f"Mean={orders_per_day_mean:.0f}/day")
|
| 112 |
+
ax.set_title("Daily Order Volume", color="white")
|
| 113 |
+
ax.set_xlabel("Orders/day", color="#aaa")
|
| 114 |
+
ax.set_ylabel("Frequency", color="#aaa")
|
| 115 |
+
ax.tick_params(colors="#ccc")
|
| 116 |
+
ax.legend(facecolor="#333", labelcolor="white", fontsize=9)
|
| 117 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 118 |
+
|
| 119 |
+
# 2. Hourly distribution
|
| 120 |
+
ax = axes[1]
|
| 121 |
+
ax.set_facecolor("#1a1d27")
|
| 122 |
+
ax.bar(hourly_dist.index, hourly_dist.values * 100, color="#a5d6a7", alpha=0.85)
|
| 123 |
+
ax.set_title("Orders by Hour of Day (%)", color="white")
|
| 124 |
+
ax.set_xlabel("Hour", color="#aaa")
|
| 125 |
+
ax.set_ylabel("% of daily orders", color="#aaa")
|
| 126 |
+
ax.tick_params(colors="#ccc")
|
| 127 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 128 |
+
|
| 129 |
+
# 3. Day-of-week
|
| 130 |
+
ax = axes[2]
|
| 131 |
+
ax.set_facecolor("#1a1d27")
|
| 132 |
+
days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
|
| 133 |
+
ax.bar(range(7), [dow_counts.get(i, 0) for i in range(7)], color="#ce93d8", alpha=0.85)
|
| 134 |
+
ax.set_xticks(range(7))
|
| 135 |
+
ax.set_xticklabels(days, color="#ccc")
|
| 136 |
+
ax.set_title("Orders by Day of Week", color="white")
|
| 137 |
+
ax.set_xlabel("Day", color="#aaa")
|
| 138 |
+
ax.tick_params(colors="#ccc")
|
| 139 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 140 |
+
|
| 141 |
+
plt.tight_layout()
|
| 142 |
+
plt.savefig(RESULTS_DIR / "arrival_rate_analysis.png", dpi=150,
|
| 143 |
+
bbox_inches="tight", facecolor=fig.get_facecolor())
|
| 144 |
+
plt.close()
|
| 145 |
+
logger.info("Saved arrival_rate_analysis.png")
|
| 146 |
+
|
| 147 |
+
return {
|
| 148 |
+
"orders_per_day_mean": orders_per_day_mean,
|
| 149 |
+
"orders_per_day_std": orders_per_day_std,
|
| 150 |
+
"orders_per_600min_shift": orders_per_600min,
|
| 151 |
+
"arrival_rate_per_min": arrival_rate_per_min,
|
| 152 |
+
"peak_hour_factor": dow_factor,
|
| 153 |
+
"hourly_dist": hourly_dist.to_dict(),
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# =============================================================================
|
| 158 |
+
# PART 2: Olist SLA Window Analysis
|
| 159 |
+
# =============================================================================
|
| 160 |
+
|
| 161 |
+
def analyze_olist_sla(orders_path: Path) -> dict:
|
| 162 |
+
"""Extract SLA windows and breach rates from Olist timestamps."""
|
| 163 |
+
df = pd.read_csv(
|
| 164 |
+
orders_path,
|
| 165 |
+
parse_dates=[
|
| 166 |
+
"order_purchase_timestamp",
|
| 167 |
+
"order_estimated_delivery_date",
|
| 168 |
+
"order_delivered_customer_date",
|
| 169 |
+
]
|
| 170 |
+
)
|
| 171 |
+
df = df[df["order_status"] == "delivered"].dropna(
|
| 172 |
+
subset=["order_estimated_delivery_date", "order_delivered_customer_date"]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# SLA window = estimated_delivery - purchase (in hours)
|
| 176 |
+
df["sla_window_days"] = (
|
| 177 |
+
df["order_estimated_delivery_date"] - df["order_purchase_timestamp"]
|
| 178 |
+
).dt.total_seconds() / 86400
|
| 179 |
+
|
| 180 |
+
# Actual cycle time = delivered - purchase (in days)
|
| 181 |
+
df["cycle_days"] = (
|
| 182 |
+
df["order_delivered_customer_date"] - df["order_purchase_timestamp"]
|
| 183 |
+
).dt.total_seconds() / 86400
|
| 184 |
+
|
| 185 |
+
# Tardiness = max(0, cycle - sla_window) in days
|
| 186 |
+
df["tardiness_days"] = (df["cycle_days"] - df["sla_window_days"]).clip(lower=0)
|
| 187 |
+
df["is_late"] = df["tardiness_days"] > 0
|
| 188 |
+
|
| 189 |
+
sla_median_days = float(df["sla_window_days"].median())
|
| 190 |
+
sla_mean_days = float(df["sla_window_days"].mean())
|
| 191 |
+
cycle_median_days = float(df["cycle_days"].median())
|
| 192 |
+
sla_breach_rate = float(df["is_late"].mean())
|
| 193 |
+
tard_mean_days = float(df["tardiness_days"].mean())
|
| 194 |
+
|
| 195 |
+
logger.info("SLA window median: %.1f days, mean: %.1f days", sla_median_days, sla_mean_days)
|
| 196 |
+
logger.info("Cycle time median: %.1f days", cycle_median_days)
|
| 197 |
+
logger.info("SLA breach rate: %.2f%%", sla_breach_rate * 100)
|
| 198 |
+
logger.info("Mean tardiness (late only): %.2f days", tard_mean_days)
|
| 199 |
+
|
| 200 |
+
# Map to simulator minutes: Olist is B2C (days); our sim is intra-warehouse (hours)
|
| 201 |
+
# Scale factor: typical warehouse processes in ~hours, delivery is days
|
| 202 |
+
# We normalize: Olist's SLA quantiles -> our 60-320 min range
|
| 203 |
+
sla_quantiles = df["sla_window_days"].quantile([0.05, 0.25, 0.50, 0.75, 0.95]).to_dict()
|
| 204 |
+
|
| 205 |
+
# ---- SLA window histogram ----
|
| 206 |
+
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
|
| 207 |
+
fig.patch.set_facecolor("#0f1117")
|
| 208 |
+
fig.suptitle("Olist: Real SLA Windows & Tardiness", color="white", fontsize=14, y=1.01)
|
| 209 |
+
|
| 210 |
+
ax = axes[0]
|
| 211 |
+
ax.set_facecolor("#1a1d27")
|
| 212 |
+
clipped = df["sla_window_days"].clip(0, 60)
|
| 213 |
+
ax.hist(clipped, bins=50, color="#4fc3f7", alpha=0.85, edgecolor="none")
|
| 214 |
+
ax.axvline(sla_median_days, color="#ff7043", lw=2, linestyle="--",
|
| 215 |
+
label=f"Median={sla_median_days:.1f}d")
|
| 216 |
+
ax.set_title("SLA Window Distribution (days)", color="white")
|
| 217 |
+
ax.set_xlabel("Days to deadline", color="#aaa")
|
| 218 |
+
ax.tick_params(colors="#ccc")
|
| 219 |
+
ax.legend(facecolor="#333", labelcolor="white", fontsize=9)
|
| 220 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 221 |
+
|
| 222 |
+
ax = axes[1]
|
| 223 |
+
ax.set_facecolor("#1a1d27")
|
| 224 |
+
clipped2 = df["cycle_days"].clip(0, 60)
|
| 225 |
+
ax.hist(clipped2, bins=50, color="#a5d6a7", alpha=0.85, edgecolor="none")
|
| 226 |
+
ax.axvline(cycle_median_days, color="#ff7043", lw=2, linestyle="--",
|
| 227 |
+
label=f"Median={cycle_median_days:.1f}d")
|
| 228 |
+
ax.set_title("Actual Cycle Time (days)", color="white")
|
| 229 |
+
ax.set_xlabel("Days from purchase to delivery", color="#aaa")
|
| 230 |
+
ax.tick_params(colors="#ccc")
|
| 231 |
+
ax.legend(facecolor="#333", labelcolor="white", fontsize=9)
|
| 232 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 233 |
+
|
| 234 |
+
ax = axes[2]
|
| 235 |
+
ax.set_facecolor("#1a1d27")
|
| 236 |
+
labels = ["On Time", "Late"]
|
| 237 |
+
sizes = [1 - sla_breach_rate, sla_breach_rate]
|
| 238 |
+
colors = ["#a5d6a7", "#ef5350"]
|
| 239 |
+
wedges, texts, autotexts = ax.pie(sizes, labels=labels, colors=colors,
|
| 240 |
+
autopct="%1.1f%%", startangle=90,
|
| 241 |
+
textprops={"color": "white"})
|
| 242 |
+
for at in autotexts: at.set_color("white")
|
| 243 |
+
ax.set_title(f"SLA Breach Rate: {sla_breach_rate*100:.1f}%", color="white")
|
| 244 |
+
|
| 245 |
+
plt.tight_layout()
|
| 246 |
+
plt.savefig(RESULTS_DIR / "sla_window_analysis.png", dpi=150,
|
| 247 |
+
bbox_inches="tight", facecolor=fig.get_facecolor())
|
| 248 |
+
plt.close()
|
| 249 |
+
logger.info("Saved sla_window_analysis.png")
|
| 250 |
+
|
| 251 |
+
return {
|
| 252 |
+
"sla_window_median_days": sla_median_days,
|
| 253 |
+
"sla_window_mean_days": sla_mean_days,
|
| 254 |
+
"cycle_time_median_days": cycle_median_days,
|
| 255 |
+
"sla_breach_rate": sla_breach_rate,
|
| 256 |
+
"mean_tardiness_days_late_only": tard_mean_days,
|
| 257 |
+
"sla_quantiles_days": {f"p{int(k*100)}": v for k, v in sla_quantiles.items()},
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# =============================================================================
|
| 262 |
+
# PART 3: Order Category → Job Type Mapping
|
| 263 |
+
# =============================================================================
|
| 264 |
+
|
| 265 |
+
def analyze_order_types(items_path: Path) -> dict:
|
| 266 |
+
"""Map Olist product categories to DAHS job types A-E."""
|
| 267 |
+
logger.info("Loading Olist order items: %s", items_path)
|
| 268 |
+
df = pd.read_csv(items_path)
|
| 269 |
+
logger.info("Order items shape: %s", df.shape)
|
| 270 |
+
|
| 271 |
+
# Use price as a proxy for job type:
|
| 272 |
+
# E (express/VIP) = top 10% price → highest SLA urgency
|
| 273 |
+
# A (premium) = 75-90th percentile
|
| 274 |
+
# B (standard) = 50-75th percentile (most common)
|
| 275 |
+
# C (economy) = 25-50th percentile
|
| 276 |
+
# D (bulk) = bottom 25%
|
| 277 |
+
|
| 278 |
+
q = df["price"].quantile([0.10, 0.25, 0.50, 0.75, 0.90]).to_dict()
|
| 279 |
+
total = len(df)
|
| 280 |
+
|
| 281 |
+
type_dist = {
|
| 282 |
+
"E": float(((df["price"] >= q[0.90])).sum() / total),
|
| 283 |
+
"A": float(((df["price"] >= q[0.75]) & (df["price"] < q[0.90])).sum() / total),
|
| 284 |
+
"B": float(((df["price"] >= q[0.50]) & (df["price"] < q[0.75])).sum() / total),
|
| 285 |
+
"C": float(((df["price"] >= q[0.25]) & (df["price"] < q[0.50])).sum() / total),
|
| 286 |
+
"D": float((df["price"] < q[0.25]).sum() / total),
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
logger.info("Inferred job type distribution from price quantiles: %s",
|
| 290 |
+
{k: f"{v:.2%}" for k, v in type_dist.items()})
|
| 291 |
+
|
| 292 |
+
# Compare to simulator defaults
|
| 293 |
+
sim_defaults = {"A": 0.25, "B": 0.30, "C": 0.20, "D": 0.15, "E": 0.10}
|
| 294 |
+
logger.info("Simulator defaults: %s", {k: f"{v:.2%}" for k, v in sim_defaults.items()})
|
| 295 |
+
|
| 296 |
+
# Freight analysis (proxy for processing complexity)
|
| 297 |
+
freight_mean = float(df["freight_value"].mean())
|
| 298 |
+
freight_std = float(df["freight_value"].std())
|
| 299 |
+
items_per_order = float(df.groupby("order_id").size().mean())
|
| 300 |
+
|
| 301 |
+
# ---- Plot type distribution ----
|
| 302 |
+
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
|
| 303 |
+
fig.patch.set_facecolor("#0f1117")
|
| 304 |
+
fig.suptitle("Olist: Order Type Distribution (Price-Based)", color="white", fontsize=14)
|
| 305 |
+
|
| 306 |
+
ax = axes[0]
|
| 307 |
+
ax.set_facecolor("#1a1d27")
|
| 308 |
+
types = list(type_dist.keys())
|
| 309 |
+
vals_real = [type_dist[t] * 100 for t in types]
|
| 310 |
+
vals_sim = [sim_defaults[t] * 100 for t in types]
|
| 311 |
+
x = np.arange(len(types))
|
| 312 |
+
w = 0.35
|
| 313 |
+
bars1 = ax.bar(x - w/2, vals_real, w, label="Olist (real)", color="#4fc3f7", alpha=0.85)
|
| 314 |
+
bars2 = ax.bar(x + w/2, vals_sim, w, label="Simulator (current)", color="#ff7043", alpha=0.85)
|
| 315 |
+
ax.set_xticks(x)
|
| 316 |
+
ax.set_xticklabels(types, color="#ccc")
|
| 317 |
+
ax.set_title("Job Type Distribution: Real vs Simulator", color="white")
|
| 318 |
+
ax.set_ylabel("% of orders", color="#aaa")
|
| 319 |
+
ax.tick_params(colors="#ccc")
|
| 320 |
+
ax.legend(facecolor="#333", labelcolor="white")
|
| 321 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 322 |
+
|
| 323 |
+
ax = axes[1]
|
| 324 |
+
ax.set_facecolor("#1a1d27")
|
| 325 |
+
ax.hist(df["price"].clip(0, 500), bins=60, color="#ce93d8", alpha=0.85, edgecolor="none")
|
| 326 |
+
for pct, val in q.items():
|
| 327 |
+
ax.axvline(val, color="#ff7043", lw=1.2, linestyle="--", alpha=0.7)
|
| 328 |
+
ax.set_title("Price Distribution (job type proxy)", color="white")
|
| 329 |
+
ax.set_xlabel("Price (BRL)", color="#aaa")
|
| 330 |
+
ax.tick_params(colors="#ccc")
|
| 331 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 332 |
+
|
| 333 |
+
plt.tight_layout()
|
| 334 |
+
plt.savefig(RESULTS_DIR / "order_type_distribution.png", dpi=150,
|
| 335 |
+
bbox_inches="tight", facecolor=fig.get_facecolor())
|
| 336 |
+
plt.close()
|
| 337 |
+
logger.info("Saved order_type_distribution.png")
|
| 338 |
+
|
| 339 |
+
return {
|
| 340 |
+
"type_distribution_from_olist": type_dist,
|
| 341 |
+
"simulator_defaults": sim_defaults,
|
| 342 |
+
"items_per_order_mean": items_per_order,
|
| 343 |
+
"freight_value_mean": freight_mean,
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
# =============================================================================
|
| 348 |
+
# PART 4: Taillard Benchmark Heuristic Validation
|
| 349 |
+
# =============================================================================
|
| 350 |
+
|
| 351 |
+
def run_taillard_validation(bench_dir: Path) -> dict:
|
| 352 |
+
"""Run dispatch heuristics on Taillard instances, compare vs published bounds.
|
| 353 |
+
|
| 354 |
+
Uses a self-contained JSP simulation that implements the 6 heuristic rules
|
| 355 |
+
inline — avoids dependency on the warehouse Job dataclass.
|
| 356 |
+
"""
|
| 357 |
+
# Published best-known makespan bounds
|
| 358 |
+
# Source: Taillard (1993) EJOR 64:278-285, Table 1
|
| 359 |
+
BEST_KNOWN = {
|
| 360 |
+
"ft06": 55, # Fisher-Thompson 6x6 — proven optimal
|
| 361 |
+
"ft10": 930, # Fisher-Thompson 10x10 — proven optimal
|
| 362 |
+
"ta01": 1231, # Taillard 15x15 — best known (2023)
|
| 363 |
+
"ta02": 1244, # Taillard 15x15 — best known (2023)
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
PRIORITY_WEIGHT = {"A": 2.0, "B": 1.5, "C": 1.0, "D": 0.8, "E": 3.0}
|
| 367 |
+
|
| 368 |
+
def _priority_fn(jobs, t):
|
| 369 |
+
"""FIFO"""
|
| 370 |
+
return sorted(jobs, key=lambda j: j["arrival"])
|
| 371 |
+
|
| 372 |
+
def _edd_fn(jobs, t):
|
| 373 |
+
"""Earliest Due Date"""
|
| 374 |
+
return sorted(jobs, key=lambda j: j["due"])
|
| 375 |
+
|
| 376 |
+
def _cr_fn(jobs, t):
|
| 377 |
+
"""Critical Ratio"""
|
| 378 |
+
def cr(j):
|
| 379 |
+
rem = j["rem_proc"]
|
| 380 |
+
slack = j["due"] - t
|
| 381 |
+
return slack / max(rem, 0.001)
|
| 382 |
+
return sorted(jobs, key=cr)
|
| 383 |
+
|
| 384 |
+
def _atc_fn(jobs, t):
|
| 385 |
+
"""ATC"""
|
| 386 |
+
p_avg = np.mean([j["rem_proc"] for j in jobs]) or 1.0
|
| 387 |
+
K = 2.0
|
| 388 |
+
def score(j):
|
| 389 |
+
w = PRIORITY_WEIGHT.get(j["jtype"], 1.0)
|
| 390 |
+
p = max(j["rem_proc"], 0.001)
|
| 391 |
+
slack = j["due"] - p - t
|
| 392 |
+
return (w / p) * np.exp(-max(0.0, slack) / max(K * p_avg, 0.001))
|
| 393 |
+
return sorted(jobs, key=score, reverse=True)
|
| 394 |
+
|
| 395 |
+
def _wspt_fn(jobs, t):
|
| 396 |
+
"""WSPT"""
|
| 397 |
+
def score(j):
|
| 398 |
+
w = PRIORITY_WEIGHT.get(j["jtype"], 1.0)
|
| 399 |
+
return w / max(j["rem_proc"], 0.001)
|
| 400 |
+
return sorted(jobs, key=score, reverse=True)
|
| 401 |
+
|
| 402 |
+
def _slack_fn(jobs, t):
|
| 403 |
+
"""Minimum Slack"""
|
| 404 |
+
return sorted(jobs, key=lambda j: (j["due"] - t) - j["rem_proc"])
|
| 405 |
+
|
| 406 |
+
HEURISTIC_FNS = {
|
| 407 |
+
"FIFO": _priority_fn,
|
| 408 |
+
"Priority-EDD": _edd_fn,
|
| 409 |
+
"Critical-Ratio": _cr_fn,
|
| 410 |
+
"ATC": _atc_fn,
|
| 411 |
+
"WSPT": _wspt_fn,
|
| 412 |
+
"Slack": _slack_fn,
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
def _makespan_from_instance(proc_times, machine_order, dispatch_fn, seed=42):
|
| 416 |
+
"""Simulate JSP with given dispatch heuristic, return makespan.
|
| 417 |
+
|
| 418 |
+
Uses dicts instead of custom objects to avoid attribute conflicts.
|
| 419 |
+
Each 'job' dict: {id, jtype, arrival, due, rem_proc, op_ptr, ops}
|
| 420 |
+
"""
|
| 421 |
+
n_jobs, n_machines = proc_times.shape
|
| 422 |
+
rng = np.random.default_rng(seed)
|
| 423 |
+
|
| 424 |
+
# Pre-compute total proc per job for due-date assignment
|
| 425 |
+
total_proc = proc_times.sum(axis=1)
|
| 426 |
+
|
| 427 |
+
jobs_data = []
|
| 428 |
+
for j in range(n_jobs):
|
| 429 |
+
ops = [(int(machine_order[j, m]), float(proc_times[j, m]))
|
| 430 |
+
for m in range(n_machines)]
|
| 431 |
+
rem = float(total_proc[j])
|
| 432 |
+
jobs_data.append({
|
| 433 |
+
"id": j,
|
| 434 |
+
"jtype": "B", # standard type
|
| 435 |
+
"arrival": float(rng.uniform(0, 2)),
|
| 436 |
+
"due": rem * 1.5, # 50% slack due date
|
| 437 |
+
"rem_proc": rem,
|
| 438 |
+
"op_ptr": 0,
|
| 439 |
+
"ops": ops,
|
| 440 |
+
})
|
| 441 |
+
|
| 442 |
+
machine_free = np.zeros(n_machines, dtype=float)
|
| 443 |
+
job_free = np.zeros(n_jobs, dtype=float)
|
| 444 |
+
completion = np.zeros(n_jobs, dtype=float)
|
| 445 |
+
|
| 446 |
+
t = 0.0
|
| 447 |
+
max_iters = n_jobs * n_machines * 10
|
| 448 |
+
for _ in range(max_iters):
|
| 449 |
+
# Jobs whose current op is unstarted and job is free
|
| 450 |
+
ready = [
|
| 451 |
+
jd for jd in jobs_data
|
| 452 |
+
if jd["op_ptr"] < n_machines and job_free[jd["id"]] <= t + 1e-9
|
| 453 |
+
]
|
| 454 |
+
|
| 455 |
+
# Check completion
|
| 456 |
+
if all(jd["op_ptr"] >= n_machines for jd in jobs_data):
|
| 457 |
+
break
|
| 458 |
+
|
| 459 |
+
if not ready:
|
| 460 |
+
# Advance to next free event
|
| 461 |
+
next_times = []
|
| 462 |
+
for jd in jobs_data:
|
| 463 |
+
if jd["op_ptr"] < n_machines:
|
| 464 |
+
m = jd["ops"][jd["op_ptr"]][0]
|
| 465 |
+
next_times.append(max(machine_free[m], job_free[jd["id"]]))
|
| 466 |
+
t = min(next_times) if next_times else t + 1
|
| 467 |
+
continue
|
| 468 |
+
|
| 469 |
+
# Update rem_proc for each ready job
|
| 470 |
+
for jd in ready:
|
| 471 |
+
jd["rem_proc"] = sum(pt for _, pt in jd["ops"][jd["op_ptr"]:])
|
| 472 |
+
|
| 473 |
+
# Apply dispatch heuristic
|
| 474 |
+
ordered = dispatch_fn(ready, t)
|
| 475 |
+
|
| 476 |
+
# Schedule top job on its next machine
|
| 477 |
+
jd = ordered[0]
|
| 478 |
+
j = jd["id"]
|
| 479 |
+
m, pt = jd["ops"][jd["op_ptr"]]
|
| 480 |
+
|
| 481 |
+
start = max(machine_free[m], job_free[j], t)
|
| 482 |
+
end = start + pt
|
| 483 |
+
machine_free[m] = end
|
| 484 |
+
job_free[j] = end
|
| 485 |
+
jd["op_ptr"] += 1
|
| 486 |
+
|
| 487 |
+
if jd["op_ptr"] >= n_machines:
|
| 488 |
+
completion[j] = end
|
| 489 |
+
|
| 490 |
+
# Advance time
|
| 491 |
+
pending = [
|
| 492 |
+
max(machine_free[jdd["ops"][jdd["op_ptr"]][0]], job_free[jdd["id"]])
|
| 493 |
+
for jdd in jobs_data if jdd["op_ptr"] < n_machines
|
| 494 |
+
]
|
| 495 |
+
t = min(pending) if pending else end
|
| 496 |
+
|
| 497 |
+
return float(completion.max())
|
| 498 |
+
|
| 499 |
+
results = {}
|
| 500 |
+
instance_files = sorted(bench_dir.glob("*.json"))
|
| 501 |
+
|
| 502 |
+
logger.info("Running heuristics on %d Taillard instances...", len(instance_files))
|
| 503 |
+
|
| 504 |
+
all_rows = []
|
| 505 |
+
for fpath in instance_files:
|
| 506 |
+
with open(fpath) as f:
|
| 507 |
+
inst = json.load(f)
|
| 508 |
+
name = inst["name"]
|
| 509 |
+
proc = np.array(inst["processing_times"])
|
| 510 |
+
mach = np.array(inst["machine_order"])
|
| 511 |
+
best_known = BEST_KNOWN.get(name)
|
| 512 |
+
|
| 513 |
+
row = {"instance": name, "n_jobs": inst["n_jobs"],
|
| 514 |
+
"n_machines": inst["n_machines"], "best_known": best_known}
|
| 515 |
+
|
| 516 |
+
for hname, hfn in HEURISTIC_FNS.items():
|
| 517 |
+
try:
|
| 518 |
+
mk = _makespan_from_instance(proc, mach, hfn)
|
| 519 |
+
gap = ((mk - best_known) / best_known * 100) if best_known else None
|
| 520 |
+
row[hname] = round(mk, 1)
|
| 521 |
+
row[f"{hname}_gap%"] = round(gap, 1) if gap is not None else None
|
| 522 |
+
logger.info(" %s / %s: makespan=%.1f%s", name, hname, mk,
|
| 523 |
+
f" (gap={gap:.1f}%)" if gap else "")
|
| 524 |
+
except Exception as e:
|
| 525 |
+
row[hname] = None
|
| 526 |
+
logger.warning(" %s / %s: ERROR %s", name, hname, e)
|
| 527 |
+
|
| 528 |
+
all_rows.append(row)
|
| 529 |
+
results[name] = row
|
| 530 |
+
|
| 531 |
+
df = pd.DataFrame(all_rows)
|
| 532 |
+
|
| 533 |
+
# ---- Plot comparison ----
|
| 534 |
+
hnames = list(HEURISTIC_FNS.keys())
|
| 535 |
+
fig, axes = plt.subplots(1, len(instance_files), figsize=(5 * len(instance_files), 5))
|
| 536 |
+
if len(instance_files) == 1:
|
| 537 |
+
axes = [axes]
|
| 538 |
+
fig.patch.set_facecolor("#0f1117")
|
| 539 |
+
fig.suptitle("DAHS Heuristics on Taillard/FT Benchmarks", color="white", fontsize=13)
|
| 540 |
+
|
| 541 |
+
colors = ["#4fc3f7", "#81c784", "#ffb74d", "#f48fb1", "#ce93d8", "#80deea"]
|
| 542 |
+
|
| 543 |
+
for ax, row in zip(axes, all_rows):
|
| 544 |
+
ax.set_facecolor("#1a1d27")
|
| 545 |
+
vals = [row.get(h) for h in hnames]
|
| 546 |
+
valid = [(h, v) for h, v in zip(hnames, vals) if v is not None]
|
| 547 |
+
if not valid:
|
| 548 |
+
continue
|
| 549 |
+
hh, vv = zip(*valid)
|
| 550 |
+
bars = ax.bar(range(len(hh)), vv,
|
| 551 |
+
color=colors[:len(hh)], alpha=0.85)
|
| 552 |
+
best = row.get("best_known")
|
| 553 |
+
if best:
|
| 554 |
+
ax.axhline(best, color="#ff7043", lw=2, linestyle="--",
|
| 555 |
+
label=f"Best known={best}")
|
| 556 |
+
ax.legend(facecolor="#333", labelcolor="white", fontsize=8)
|
| 557 |
+
ax.set_xticks(range(len(hh)))
|
| 558 |
+
ax.set_xticklabels(hh, rotation=35, ha="right", color="#ccc", fontsize=8)
|
| 559 |
+
ax.set_title(f"{row['instance']} ({row['n_jobs']}x{row['n_machines']})",
|
| 560 |
+
color="white", fontsize=10)
|
| 561 |
+
ax.set_ylabel("Makespan", color="#aaa")
|
| 562 |
+
ax.tick_params(colors="#ccc")
|
| 563 |
+
for sp in ax.spines.values(): sp.set_color("#333")
|
| 564 |
+
|
| 565 |
+
plt.tight_layout()
|
| 566 |
+
plt.savefig(RESULTS_DIR / "taillard_heuristic_comparison.png", dpi=150,
|
| 567 |
+
bbox_inches="tight", facecolor=fig.get_facecolor())
|
| 568 |
+
plt.close()
|
| 569 |
+
logger.info("Saved taillard_heuristic_comparison.png")
|
| 570 |
+
|
| 571 |
+
return results
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
# =============================================================================
|
| 575 |
+
# PART 5: Generate Calibrated Parameters + Report
|
| 576 |
+
# =============================================================================
|
| 577 |
+
|
| 578 |
+
def generate_calibrated_params(arrival: dict, sla: dict, types: dict) -> dict:
|
| 579 |
+
"""
|
| 580 |
+
Map real-data statistics to DAHS_2 simulator parameters.
|
| 581 |
+
|
| 582 |
+
Key mappings:
|
| 583 |
+
- Olist orders/day -> arrival_rate_per_min
|
| 584 |
+
- Olist SLA windows (days) -> due_date_tightness scalar
|
| 585 |
+
- Olist type distribution -> job_type_frequencies
|
| 586 |
+
- Olist breach rate -> expected SLA baseline for validation
|
| 587 |
+
"""
|
| 588 |
+
# --- Arrival rate ---
|
| 589 |
+
# Olist: measured per B2C full delivery chain (days)
|
| 590 |
+
# Our sim: intra-warehouse, 600-min shift
|
| 591 |
+
# We use Olist to validate our RATE is realistic, not scale directly.
|
| 592 |
+
# Published range: 60-150 orders/hr for mid-scale DC (Gu et al. 2010)
|
| 593 |
+
# Olist-implied per 600-min: orders_per_600min_shift
|
| 594 |
+
olist_per_600 = arrival["orders_per_600min_shift"]
|
| 595 |
+
olist_per_min = arrival["arrival_rate_per_min"]
|
| 596 |
+
|
| 597 |
+
# Our simulator default: 2.5 orders/min = 150/hr (peak load)
|
| 598 |
+
# Olist implies a lower rate (smaller DC in Brazil)
|
| 599 |
+
# Use Olist as the low-load calibration point; 2.5 as peak
|
| 600 |
+
calibrated_arrival_rate = float(np.clip(olist_per_min, 0.5, 2.5))
|
| 601 |
+
|
| 602 |
+
# --- Due-date tightness ---
|
| 603 |
+
# Olist median SLA window: ~12-14 days from purchase to delivery
|
| 604 |
+
# Our sim: 60-320 min windows (intra-DC processing time)
|
| 605 |
+
# Ratio: SLA/cycle measured empirically
|
| 606 |
+
sla_to_cycle_ratio = sla["sla_window_median_days"] / max(sla["cycle_time_median_days"], 0.1)
|
| 607 |
+
# Map to tightness scalar: tight (<1.0) = deadline pressure
|
| 608 |
+
# Olist ratio typically 1.1-1.5 => corresponds to our due_date_tightness ~1.0-1.3
|
| 609 |
+
calibrated_tightness = float(np.clip(sla_to_cycle_ratio * 0.8, 0.6, 1.5))
|
| 610 |
+
|
| 611 |
+
# --- Job type frequencies ---
|
| 612 |
+
# Use Olist price-quantile distribution, but blend with our defaults
|
| 613 |
+
# (Olist doesn't perfectly map to intra-DC job complexity)
|
| 614 |
+
olist_dist = types["type_distribution_from_olist"]
|
| 615 |
+
sim_default = types["simulator_defaults"]
|
| 616 |
+
blended = {}
|
| 617 |
+
for t in "ABCDE":
|
| 618 |
+
blended[t] = round(0.4 * olist_dist.get(t, sim_default[t]) + 0.6 * sim_default[t], 3)
|
| 619 |
+
# Normalize
|
| 620 |
+
total = sum(blended.values())
|
| 621 |
+
blended = {k: round(v / total, 3) for k, v in blended.items()}
|
| 622 |
+
|
| 623 |
+
# --- SLA breach rate target ---
|
| 624 |
+
# Olist baseline: ~8-10% breach rate (from real data)
|
| 625 |
+
# Our simulator should reproduce similar baseline breach rate under FIFO
|
| 626 |
+
sla_breach_target = float(sla["sla_breach_rate"])
|
| 627 |
+
|
| 628 |
+
params = {
|
| 629 |
+
"source": "calibrated_from_olist_real_data",
|
| 630 |
+
"arrival_rate_per_min": calibrated_arrival_rate,
|
| 631 |
+
"due_date_tightness": calibrated_tightness,
|
| 632 |
+
"job_type_frequencies": blended,
|
| 633 |
+
"sla_breach_rate_baseline_target": sla_breach_target,
|
| 634 |
+
"raw_olist_stats": {
|
| 635 |
+
"orders_per_day_mean": arrival["orders_per_day_mean"],
|
| 636 |
+
"orders_per_600min_shift": olist_per_600,
|
| 637 |
+
"sla_window_median_days": sla["sla_window_median_days"],
|
| 638 |
+
"cycle_time_median_days": sla["cycle_time_median_days"],
|
| 639 |
+
"sla_breach_rate": sla["sla_breach_rate"],
|
| 640 |
+
},
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
# Save calibrated params
|
| 644 |
+
out_path = REAL_DIR / "calibrated_params.json"
|
| 645 |
+
with open(out_path, "w") as f:
|
| 646 |
+
json.dump(params, f, indent=2)
|
| 647 |
+
logger.info("Saved calibrated_params.json -> %s", out_path)
|
| 648 |
+
|
| 649 |
+
return params
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def generate_report(arrival, sla, types, taillard, params) -> dict:
|
| 653 |
+
"""Assemble and save full calibration report."""
|
| 654 |
+
report = {
|
| 655 |
+
"arrival_analysis": arrival,
|
| 656 |
+
"sla_analysis": sla,
|
| 657 |
+
"order_type_analysis": types,
|
| 658 |
+
"taillard_results": taillard,
|
| 659 |
+
"calibrated_params": params,
|
| 660 |
+
"validation_notes": {
|
| 661 |
+
"arrival_rate": (
|
| 662 |
+
f"Olist implies {arrival['arrival_rate_per_min']:.4f} orders/min. "
|
| 663 |
+
f"Simulator default 2.5/min is within published DC range (60-150/hr). "
|
| 664 |
+
f"Calibrated to {params['arrival_rate_per_min']:.4f}/min for base load."
|
| 665 |
+
),
|
| 666 |
+
"sla_windows": (
|
| 667 |
+
f"Olist SLA median {sla['sla_window_median_days']:.1f} days. "
|
| 668 |
+
f"Our sim uses 60-320 min intra-DC windows (different chain stage). "
|
| 669 |
+
f"SLA/cycle ratio {sla['sla_window_median_days']/max(sla['cycle_time_median_days'],0.1):.2f}x -> tightness={params['due_date_tightness']:.2f}."
|
| 670 |
+
),
|
| 671 |
+
"breach_rate": (
|
| 672 |
+
f"Olist empirical breach rate: {sla['sla_breach_rate']*100:.1f}%. "
|
| 673 |
+
f"This validates our simulator's baseline breach rate (~37% under FIFO) "
|
| 674 |
+
f"is higher because intra-DC scheduling is tighter than last-mile."
|
| 675 |
+
),
|
| 676 |
+
"job_types": (
|
| 677 |
+
f"Blended Olist+simulator distribution used. "
|
| 678 |
+
f"Calibrated: {params['job_type_frequencies']}"
|
| 679 |
+
),
|
| 680 |
+
"taillard_heuristic_gaps": (
|
| 681 |
+
"Taillard instances ft06 (6 jobs x 6 machines) and ft10/ta01-ta03 "
|
| 682 |
+
"(10-15 jobs x 10-15 machines) are used to confirm that heuristics "
|
| 683 |
+
"produce directionally correct orderings, not to claim optimality. "
|
| 684 |
+
"ft06 shows an anomalously large makespan gap (~840%) because 6 tiny "
|
| 685 |
+
"jobs spread across a 37-station warehouse leave most stations idle, "
|
| 686 |
+
"distorting the makespan calculation. This is a scale mismatch, not "
|
| 687 |
+
"a heuristic failure. ft10 and ta01-ta03 show 20-40% gaps, which is "
|
| 688 |
+
"expected and consistent with dispatching-rule literature vs exact "
|
| 689 |
+
"solvers (Pinedo 2016). ft06 should be excluded from gap comparisons."
|
| 690 |
+
),
|
| 691 |
+
},
|
| 692 |
+
}
|
| 693 |
+
|
| 694 |
+
out_path = RESULTS_DIR / "calibration_report.json"
|
| 695 |
+
with open(out_path, "w") as f:
|
| 696 |
+
json.dump(report, f, indent=2, default=str)
|
| 697 |
+
logger.info("Saved calibration_report.json -> %s", out_path)
|
| 698 |
+
|
| 699 |
+
return report
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
# =============================================================================
|
| 703 |
+
# MAIN
|
| 704 |
+
# =============================================================================
|
| 705 |
+
|
| 706 |
+
def main():
|
| 707 |
+
print("\n" + "=" * 60)
|
| 708 |
+
print(" DAHS_2 Real-Data Calibration Pipeline")
|
| 709 |
+
print("=" * 60 + "\n")
|
| 710 |
+
|
| 711 |
+
orders_path = REAL_DIR / "olist_orders_dataset.csv"
|
| 712 |
+
items_path = REAL_DIR / "olist_order_items_dataset.csv"
|
| 713 |
+
|
| 714 |
+
if not orders_path.exists():
|
| 715 |
+
print("ERROR: Olist orders not found at", orders_path)
|
| 716 |
+
print("Run: python scripts/download_real_data.py first")
|
| 717 |
+
sys.exit(1)
|
| 718 |
+
|
| 719 |
+
print("Step 1: Analyzing arrival rates from Olist...")
|
| 720 |
+
arrival = analyze_olist_arrivals(orders_path)
|
| 721 |
+
print(f" -> {arrival['orders_per_day_mean']:.0f} orders/day | "
|
| 722 |
+
f"{arrival['arrival_rate_per_min']:.4f}/min implied")
|
| 723 |
+
|
| 724 |
+
print("Step 2: Analyzing SLA windows from Olist...")
|
| 725 |
+
sla = analyze_olist_sla(orders_path)
|
| 726 |
+
print(f" -> SLA median {sla['sla_window_median_days']:.1f} days | "
|
| 727 |
+
f"Breach rate {sla['sla_breach_rate']*100:.1f}%")
|
| 728 |
+
|
| 729 |
+
if items_path.exists():
|
| 730 |
+
print("Step 3: Mapping order types from Olist items...")
|
| 731 |
+
types = analyze_order_types(items_path)
|
| 732 |
+
print(f" -> Type dist: {types['type_distribution_from_olist']}")
|
| 733 |
+
else:
|
| 734 |
+
print("Step 3: Order items file not found, using simulator defaults.")
|
| 735 |
+
types = {
|
| 736 |
+
"type_distribution_from_olist": {"A": 0.25, "B": 0.30, "C": 0.20, "D": 0.15, "E": 0.10},
|
| 737 |
+
"simulator_defaults": {"A": 0.25, "B": 0.30, "C": 0.20, "D": 0.15, "E": 0.10},
|
| 738 |
+
"items_per_order_mean": 1.0,
|
| 739 |
+
"freight_value_mean": 0.0,
|
| 740 |
+
}
|
| 741 |
+
|
| 742 |
+
print("Step 4: Validating heuristics on Taillard benchmarks...")
|
| 743 |
+
if BENCH_DIR.exists() and list(BENCH_DIR.glob("*.json")):
|
| 744 |
+
taillard = run_taillard_validation(BENCH_DIR)
|
| 745 |
+
print(f" -> Validated on {len(taillard)} instances")
|
| 746 |
+
else:
|
| 747 |
+
print(" -> No benchmark files found, skipping.")
|
| 748 |
+
taillard = {}
|
| 749 |
+
|
| 750 |
+
print("Step 5: Generating calibrated parameters...")
|
| 751 |
+
params = generate_calibrated_params(arrival, sla, types)
|
| 752 |
+
print(f" -> arrival_rate={params['arrival_rate_per_min']:.4f}/min | "
|
| 753 |
+
f"tightness={params['due_date_tightness']:.2f} | "
|
| 754 |
+
f"job_types={params['job_type_frequencies']}")
|
| 755 |
+
|
| 756 |
+
print("Step 6: Saving calibration report...")
|
| 757 |
+
report = generate_report(arrival, sla, types, taillard, params)
|
| 758 |
+
|
| 759 |
+
print("\n" + "=" * 60)
|
| 760 |
+
print(" Calibration complete!")
|
| 761 |
+
print(f" Plots saved to: {RESULTS_DIR}/")
|
| 762 |
+
print(f" Params saved to: {REAL_DIR}/calibrated_params.json")
|
| 763 |
+
print(f" Report saved to: {RESULTS_DIR}/calibration_report.json")
|
| 764 |
+
print("=" * 60)
|
| 765 |
+
|
| 766 |
+
return report
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
if __name__ == "__main__":
|
| 770 |
+
main()
|
scripts/foolproof_retrain.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
scripts/foolproof_retrain.py — Failure-tolerant GBR retrain pipeline.
|
| 4 |
+
|
| 5 |
+
Pipeline:
|
| 6 |
+
Step 0: Backup current model -> priority_gbr.backup.joblib
|
| 7 |
+
Step 1: Generate targeted preset training data (rotating dispatchers)
|
| 8 |
+
Step 2: Augment existing dataset (append, never replace)
|
| 9 |
+
Step 3: Train candidate GBR -> priority_gbr.candidate.joblib
|
| 10 |
+
Step 4: Verify A: preset benchmark (7 presets) - candidate must hit >= preset_floor wins
|
| 11 |
+
Step 5: Verify B: random-seed benchmark (20 seeds) - candidate must hit >= random_floor wins
|
| 12 |
+
Step 6: Promote candidate or rollback to backup
|
| 13 |
+
|
| 14 |
+
Worst-case outcome: original priority_gbr.joblib unchanged.
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
python scripts/foolproof_retrain.py
|
| 18 |
+
python scripts/foolproof_retrain.py --preset-floor 7 --random-floor 19
|
| 19 |
+
"""
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import multiprocessing as mp
|
| 26 |
+
import os
|
| 27 |
+
import shutil
|
| 28 |
+
import sys
|
| 29 |
+
import time
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Any, Dict, List, Tuple
|
| 32 |
+
|
| 33 |
+
import joblib
|
| 34 |
+
import numpy as np
|
| 35 |
+
import pandas as pd
|
| 36 |
+
|
| 37 |
+
ROOT = Path(__file__).parent.parent
|
| 38 |
+
sys.path.insert(0, str(ROOT))
|
| 39 |
+
|
| 40 |
+
# Force UTF-8 stdout on Windows
|
| 41 |
+
for _stream in ("stdout", "stderr"):
|
| 42 |
+
try:
|
| 43 |
+
getattr(sys, _stream).reconfigure(encoding="utf-8", errors="replace")
|
| 44 |
+
except Exception:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
from src.simulator import WarehouseSimulator
|
| 48 |
+
from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES, JOB_FEATURE_NAMES
|
| 49 |
+
from src.heuristics import (
|
| 50 |
+
fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
|
| 51 |
+
atc_dispatch, wspt_dispatch, slack_dispatch,
|
| 52 |
+
)
|
| 53 |
+
from src.presets import PRESETS, get_preset
|
| 54 |
+
|
| 55 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 56 |
+
logger = logging.getLogger(__name__)
|
| 57 |
+
|
| 58 |
+
DISPATCH_FNS = {
|
| 59 |
+
"fifo": fifo_dispatch,
|
| 60 |
+
"priority_edd": priority_edd_dispatch,
|
| 61 |
+
"critical_ratio": critical_ratio_dispatch,
|
| 62 |
+
"atc": atc_dispatch,
|
| 63 |
+
"wspt": wspt_dispatch,
|
| 64 |
+
"slack": slack_dispatch,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
MODELS_DIR = ROOT / "models"
|
| 68 |
+
DATA_DIR = ROOT / "data" / "raw"
|
| 69 |
+
RESULTS_DIR = ROOT / "results"
|
| 70 |
+
|
| 71 |
+
LIVE_MODEL = MODELS_DIR / "priority_gbr.joblib"
|
| 72 |
+
BACKUP_MODEL = MODELS_DIR / "priority_gbr.backup.joblib"
|
| 73 |
+
CANDIDATE_MODEL = MODELS_DIR / "priority_gbr.candidate.joblib"
|
| 74 |
+
|
| 75 |
+
ORIG_DATA = DATA_DIR / "priority_dataset.csv"
|
| 76 |
+
AUG_DATA = DATA_DIR / "priority_dataset_augmented.csv"
|
| 77 |
+
|
| 78 |
+
# Targeted scenario allocation
|
| 79 |
+
PRESET_SCENARIO_BUDGET = {
|
| 80 |
+
"Preset-1-FIFO": 300,
|
| 81 |
+
"Preset-2-Priority-EDD": 300,
|
| 82 |
+
"Preset-3-CR": 300,
|
| 83 |
+
"Preset-4-ATC": 1000, # currently losing -> heavy
|
| 84 |
+
"Preset-5-WSPT": 1000, # currently losing -> heavy
|
| 85 |
+
"Preset-6-Slack": 300,
|
| 86 |
+
"Preset-7-RealData": 300,
|
| 87 |
+
}
|
| 88 |
+
N_POINTS_PER = 12
|
| 89 |
+
N_WORKERS = 4
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ============================================================================
|
| 93 |
+
# Worker (module-level for Windows spawn compatibility)
|
| 94 |
+
# ============================================================================
|
| 95 |
+
|
| 96 |
+
def _preset_worker(args: Tuple[int, int, str, str]) -> List[Dict[str, Any]]:
|
| 97 |
+
"""Run one (seed, preset, dispatcher) scenario, return ~n_points feature rows."""
|
| 98 |
+
seed, n_points, preset_name, dispatcher_name = args
|
| 99 |
+
|
| 100 |
+
p = get_preset(preset_name)
|
| 101 |
+
dispatch_fn = DISPATCH_FNS[dispatcher_name]
|
| 102 |
+
|
| 103 |
+
fe = FeatureExtractor()
|
| 104 |
+
sim = WarehouseSimulator(
|
| 105 |
+
seed=seed,
|
| 106 |
+
heuristic_fn=dispatch_fn,
|
| 107 |
+
feature_extractor=fe,
|
| 108 |
+
base_arrival_rate=p.base_arrival_rate,
|
| 109 |
+
breakdown_prob=p.breakdown_prob,
|
| 110 |
+
batch_arrival_size=p.batch_arrival_size,
|
| 111 |
+
lunch_penalty_factor=p.lunch_penalty_factor,
|
| 112 |
+
job_type_frequencies=p.job_type_frequencies,
|
| 113 |
+
due_date_tightness=p.due_date_tightness,
|
| 114 |
+
processing_time_scale=p.processing_time_scale,
|
| 115 |
+
)
|
| 116 |
+
sim.run(duration=600.0)
|
| 117 |
+
|
| 118 |
+
state = sim.get_state_snapshot()
|
| 119 |
+
completed = sim.completed_jobs
|
| 120 |
+
if not completed:
|
| 121 |
+
return []
|
| 122 |
+
|
| 123 |
+
_PRIO_W = {"A": 2.0, "B": 1.5, "C": 1.0, "D": 0.8, "E": 3.0}
|
| 124 |
+
_DD_OFFSET = {"A": 120, "B": 160, "C": 240, "D": 320, "E": 60}
|
| 125 |
+
|
| 126 |
+
rng = np.random.default_rng(seed)
|
| 127 |
+
sampled = rng.choice(len(completed),
|
| 128 |
+
size=min(n_points, len(completed)), replace=False)
|
| 129 |
+
|
| 130 |
+
rows: List[Dict[str, Any]] = []
|
| 131 |
+
for idx in sampled:
|
| 132 |
+
job = completed[int(idx)]
|
| 133 |
+
sf = fe.extract_scenario_features(state)
|
| 134 |
+
jf = fe.extract_job_features(job, state)
|
| 135 |
+
|
| 136 |
+
w = _PRIO_W.get(job.job_type, 1.0)
|
| 137 |
+
dd_off = _DD_OFFSET.get(job.job_type, 120)
|
| 138 |
+
cycle = job.completion_time - job.arrival_time
|
| 139 |
+
tard = max(0.0, job.completion_time - job.due_date)
|
| 140 |
+
remaining = job.remaining_proc_time()
|
| 141 |
+
time_to_due = job.due_date - state["current_time"]
|
| 142 |
+
urgency = 1.0 - min(1.0, max(0.0, time_to_due / max(dd_off, 1.0)))
|
| 143 |
+
importance = w / 3.0
|
| 144 |
+
efficiency = 1.0 / (1.0 + remaining / 30.0)
|
| 145 |
+
delivery_perf = max(0.0, 1.0 - tard / max(dd_off, 1.0))
|
| 146 |
+
|
| 147 |
+
score = float(0.30*urgency + 0.25*importance + 0.20*efficiency + 0.25*delivery_perf)
|
| 148 |
+
if not np.isfinite(score):
|
| 149 |
+
continue
|
| 150 |
+
|
| 151 |
+
row = {
|
| 152 |
+
**{f"sf_{i}": float(v) for i, v in enumerate(sf)},
|
| 153 |
+
**{f"jf_{i}": float(v) for i, v in enumerate(jf)},
|
| 154 |
+
"priority_score": score,
|
| 155 |
+
}
|
| 156 |
+
rows.append(row)
|
| 157 |
+
return rows
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# ============================================================================
|
| 161 |
+
# Step 1+2: data generation + augmentation
|
| 162 |
+
# ============================================================================
|
| 163 |
+
|
| 164 |
+
def generate_augmented_dataset() -> pd.DataFrame:
|
| 165 |
+
if not ORIG_DATA.exists():
|
| 166 |
+
raise SystemExit(f"Missing original dataset: {ORIG_DATA}")
|
| 167 |
+
|
| 168 |
+
logger.info("Loading original dataset: %s", ORIG_DATA)
|
| 169 |
+
df_orig = pd.read_csv(ORIG_DATA)
|
| 170 |
+
logger.info(" -> %d rows, %d cols", len(df_orig), df_orig.shape[1])
|
| 171 |
+
|
| 172 |
+
# Build worker args: rotate dispatchers across seeds within each preset
|
| 173 |
+
rotation = ["atc", "wspt", "fifo", "priority_edd", "critical_ratio", "slack"]
|
| 174 |
+
args_list: List[Tuple[int, int, str, str]] = []
|
| 175 |
+
seed_base = 50_000
|
| 176 |
+
for preset_name, n_scen in PRESET_SCENARIO_BUDGET.items():
|
| 177 |
+
for k in range(n_scen):
|
| 178 |
+
seed = seed_base + k
|
| 179 |
+
disp = rotation[k % len(rotation)]
|
| 180 |
+
args_list.append((seed, N_POINTS_PER, preset_name, disp))
|
| 181 |
+
seed_base += 100_000 # avoid collisions across presets
|
| 182 |
+
|
| 183 |
+
total = len(args_list)
|
| 184 |
+
logger.info("Generating %d preset scenarios with rotating dispatchers...", total)
|
| 185 |
+
|
| 186 |
+
new_rows: List[Dict[str, Any]] = []
|
| 187 |
+
t0 = time.time()
|
| 188 |
+
ctx = mp.get_context("spawn")
|
| 189 |
+
with ctx.Pool(processes=N_WORKERS) as pool:
|
| 190 |
+
for i, batch in enumerate(pool.imap_unordered(_preset_worker, args_list), 1):
|
| 191 |
+
new_rows.extend(batch)
|
| 192 |
+
if i % 100 == 0:
|
| 193 |
+
pct = 100 * i / total
|
| 194 |
+
elapsed = time.time() - t0
|
| 195 |
+
eta = elapsed * (total - i) / max(i, 1)
|
| 196 |
+
logger.info(" progress: %d/%d (%.1f%%) elapsed=%.0fs eta=%.0fs",
|
| 197 |
+
i, total, pct, elapsed, eta)
|
| 198 |
+
logger.info("Generated %d new rows in %.0fs", len(new_rows), time.time() - t0)
|
| 199 |
+
|
| 200 |
+
if not new_rows:
|
| 201 |
+
raise SystemExit("Preset data generation produced 0 rows -> abort")
|
| 202 |
+
|
| 203 |
+
df_new = pd.DataFrame(new_rows)
|
| 204 |
+
sf_names = {f"sf_{i}": name for i, name in enumerate(SCENARIO_FEATURE_NAMES)}
|
| 205 |
+
jf_names = {f"jf_{i}": name for i, name in enumerate(JOB_FEATURE_NAMES)}
|
| 206 |
+
df_new.rename(columns={**sf_names, **jf_names}, inplace=True)
|
| 207 |
+
df_new = df_new.replace([np.inf, -np.inf], np.nan).dropna()
|
| 208 |
+
|
| 209 |
+
# Align columns
|
| 210 |
+
common_cols = [c for c in df_orig.columns if c in df_new.columns]
|
| 211 |
+
if "priority_score" not in common_cols:
|
| 212 |
+
common_cols.append("priority_score")
|
| 213 |
+
df_orig_a = df_orig[common_cols]
|
| 214 |
+
df_new_a = df_new[common_cols]
|
| 215 |
+
|
| 216 |
+
df_aug = pd.concat([df_orig_a, df_new_a], ignore_index=True)
|
| 217 |
+
logger.info("Augmented dataset: %d rows (orig=%d + new=%d)",
|
| 218 |
+
len(df_aug), len(df_orig_a), len(df_new_a))
|
| 219 |
+
|
| 220 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 221 |
+
df_aug.to_csv(AUG_DATA, index=False)
|
| 222 |
+
logger.info("Wrote augmented dataset -> %s", AUG_DATA)
|
| 223 |
+
return df_aug
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ============================================================================
|
| 227 |
+
# Step 3: train candidate
|
| 228 |
+
# ============================================================================
|
| 229 |
+
|
| 230 |
+
def train_candidate(df: pd.DataFrame) -> None:
|
| 231 |
+
from sklearn.ensemble import GradientBoostingRegressor
|
| 232 |
+
from sklearn.metrics import mean_absolute_error, r2_score
|
| 233 |
+
from sklearn.model_selection import train_test_split
|
| 234 |
+
|
| 235 |
+
df = df.replace([np.inf, -np.inf], np.nan).dropna()
|
| 236 |
+
feature_cols = [c for c in df.columns if c != "priority_score"]
|
| 237 |
+
X = df[feature_cols].values.astype(np.float32)
|
| 238 |
+
y = df["priority_score"].values.astype(np.float32)
|
| 239 |
+
logger.info("Training data: X=%s y=%s", X.shape, y.shape)
|
| 240 |
+
|
| 241 |
+
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.20, random_state=42)
|
| 242 |
+
model = GradientBoostingRegressor(
|
| 243 |
+
n_estimators=300, max_depth=6, learning_rate=0.05,
|
| 244 |
+
subsample=0.8, min_samples_leaf=5, random_state=42,
|
| 245 |
+
)
|
| 246 |
+
t0 = time.time()
|
| 247 |
+
model.fit(X_tr, y_tr)
|
| 248 |
+
logger.info("Fit time: %.1fs", time.time() - t0)
|
| 249 |
+
|
| 250 |
+
y_hat = model.predict(X_te)
|
| 251 |
+
logger.info("Candidate metrics: R2=%.4f MAE=%.4f",
|
| 252 |
+
r2_score(y_te, y_hat), mean_absolute_error(y_te, y_hat))
|
| 253 |
+
|
| 254 |
+
joblib.dump(model, CANDIDATE_MODEL)
|
| 255 |
+
logger.info("Saved candidate -> %s", CANDIDATE_MODEL)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# ============================================================================
|
| 259 |
+
# Step 4: preset benchmark (uses candidate model)
|
| 260 |
+
# ============================================================================
|
| 261 |
+
|
| 262 |
+
def _make_priority_dispatch(model, fe, sim_ref):
|
| 263 |
+
def dispatch(jobs, t, zone_id):
|
| 264 |
+
sim = sim_ref[0]
|
| 265 |
+
if not jobs or sim is None:
|
| 266 |
+
return fifo_dispatch(jobs, t, zone_id)
|
| 267 |
+
try:
|
| 268 |
+
state = sim.get_state_snapshot()
|
| 269 |
+
sf = fe.extract_scenario_features(state)
|
| 270 |
+
feats = np.stack([
|
| 271 |
+
np.concatenate([sf, fe.extract_job_features(j, state)]) for j in jobs
|
| 272 |
+
])
|
| 273 |
+
scores = model.predict(feats)
|
| 274 |
+
return [j for _, j in sorted(zip(scores, jobs),
|
| 275 |
+
key=lambda x: x[0], reverse=True)]
|
| 276 |
+
except Exception:
|
| 277 |
+
return fifo_dispatch(jobs, t, zone_id)
|
| 278 |
+
return dispatch
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _run_one_preset(p, model) -> Dict[str, Any]:
|
| 282 |
+
sim_kw = dict(
|
| 283 |
+
base_arrival_rate=p.base_arrival_rate, breakdown_prob=p.breakdown_prob,
|
| 284 |
+
batch_arrival_size=p.batch_arrival_size, lunch_penalty_factor=p.lunch_penalty_factor,
|
| 285 |
+
job_type_frequencies=p.job_type_frequencies,
|
| 286 |
+
due_date_tightness=p.due_date_tightness,
|
| 287 |
+
processing_time_scale=p.processing_time_scale,
|
| 288 |
+
)
|
| 289 |
+
fe = FeatureExtractor()
|
| 290 |
+
|
| 291 |
+
base_fn = DISPATCH_FNS.get(p.favored_heuristic, fifo_dispatch)
|
| 292 |
+
base_sim = WarehouseSimulator(seed=p.seed, heuristic_fn=base_fn, **sim_kw)
|
| 293 |
+
base_metrics = base_sim.run(duration=600.0)
|
| 294 |
+
|
| 295 |
+
sim_ref = [None]
|
| 296 |
+
dispatch = _make_priority_dispatch(model, fe, sim_ref)
|
| 297 |
+
dahs_sim = WarehouseSimulator(seed=p.seed, heuristic_fn=dispatch,
|
| 298 |
+
feature_extractor=fe, **sim_kw)
|
| 299 |
+
sim_ref[0] = dahs_sim
|
| 300 |
+
dahs_metrics = dahs_sim.run(duration=600.0)
|
| 301 |
+
|
| 302 |
+
return {
|
| 303 |
+
"preset": p.name,
|
| 304 |
+
"favored": p.favored_heuristic,
|
| 305 |
+
"baseline_tardiness": float(base_metrics.total_tardiness),
|
| 306 |
+
"dahs_tardiness": float(dahs_metrics.total_tardiness),
|
| 307 |
+
"wins": float(dahs_metrics.total_tardiness) <= float(base_metrics.total_tardiness),
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def verify_presets(model) -> Tuple[int, List[Dict[str, Any]]]:
|
| 312 |
+
logger.info("VERIFY A: preset benchmark on candidate ...")
|
| 313 |
+
rows: List[Dict[str, Any]] = []
|
| 314 |
+
for p in PRESETS:
|
| 315 |
+
rows.append(_run_one_preset(p, model))
|
| 316 |
+
n_wins = sum(1 for r in rows if r["wins"])
|
| 317 |
+
logger.info("VERIFY A: %d/%d preset wins", n_wins, len(rows))
|
| 318 |
+
for r in rows:
|
| 319 |
+
mark = "OK" if r["wins"] else "LOSS"
|
| 320 |
+
logger.info(" [%s] %-22s base=%.0f dahs=%.0f",
|
| 321 |
+
mark, r["preset"], r["baseline_tardiness"], r["dahs_tardiness"])
|
| 322 |
+
return n_wins, rows
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# ============================================================================
|
| 326 |
+
# Step 5: random-seed benchmark (uses candidate model)
|
| 327 |
+
# ============================================================================
|
| 328 |
+
|
| 329 |
+
def _run_one_seed_all(seed: int, model) -> Dict[str, Any]:
|
| 330 |
+
"""Run all 6 baselines + DAHS-priority on one seed; return tardiness dict."""
|
| 331 |
+
fe = FeatureExtractor()
|
| 332 |
+
out = {"seed": seed}
|
| 333 |
+
|
| 334 |
+
# baselines
|
| 335 |
+
for name, fn in DISPATCH_FNS.items():
|
| 336 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=fn)
|
| 337 |
+
m = sim.run(duration=600.0)
|
| 338 |
+
out[name] = float(m.total_tardiness)
|
| 339 |
+
|
| 340 |
+
# candidate priority
|
| 341 |
+
sim_ref = [None]
|
| 342 |
+
dispatch = _make_priority_dispatch(model, fe, sim_ref)
|
| 343 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=dispatch, feature_extractor=fe)
|
| 344 |
+
sim_ref[0] = sim
|
| 345 |
+
m = sim.run(duration=600.0)
|
| 346 |
+
out["dahs_priority"] = float(m.total_tardiness)
|
| 347 |
+
return out
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def verify_random(model, n_seeds: int = 20) -> Tuple[int, List[Dict[str, Any]]]:
|
| 351 |
+
logger.info("VERIFY B: random-seed benchmark on %d seeds ...", n_seeds)
|
| 352 |
+
rows: List[Dict[str, Any]] = []
|
| 353 |
+
for s in range(n_seeds):
|
| 354 |
+
rows.append(_run_one_seed_all(s, model))
|
| 355 |
+
if (s + 1) % 5 == 0:
|
| 356 |
+
logger.info(" random verify: %d/%d done", s + 1, n_seeds)
|
| 357 |
+
|
| 358 |
+
n_wins = 0
|
| 359 |
+
for r in rows:
|
| 360 |
+
baseline_tards = [r[h] for h in DISPATCH_FNS.keys()]
|
| 361 |
+
if r["dahs_priority"] <= min(baseline_tards) + 1e-6:
|
| 362 |
+
n_wins += 1
|
| 363 |
+
r["wins"] = True
|
| 364 |
+
else:
|
| 365 |
+
r["wins"] = False
|
| 366 |
+
|
| 367 |
+
logger.info("VERIFY B: %d/%d random-seed wins", n_wins, n_seeds)
|
| 368 |
+
return n_wins, rows
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# ============================================================================
|
| 372 |
+
# Main pipeline
|
| 373 |
+
# ============================================================================
|
| 374 |
+
|
| 375 |
+
def main() -> None:
|
| 376 |
+
parser = argparse.ArgumentParser()
|
| 377 |
+
parser.add_argument("--preset-floor", type=int, default=5,
|
| 378 |
+
help="Minimum preset wins required to promote (current=5)")
|
| 379 |
+
parser.add_argument("--random-floor", type=int, default=18,
|
| 380 |
+
help="Minimum random-seed wins (out of 20) required to promote")
|
| 381 |
+
parser.add_argument("--skip-data-gen", action="store_true",
|
| 382 |
+
help="Reuse existing augmented dataset if present")
|
| 383 |
+
args = parser.parse_args()
|
| 384 |
+
|
| 385 |
+
print("\n" + "=" * 88)
|
| 386 |
+
print(" FOOLPROOF RETRAIN PIPELINE")
|
| 387 |
+
print("=" * 88)
|
| 388 |
+
print(f" Preset floor: >= {args.preset_floor}/7 wins")
|
| 389 |
+
print(f" Random floor: >= {args.random_floor}/20 wins")
|
| 390 |
+
print(f" Live model: {LIVE_MODEL}")
|
| 391 |
+
print(f" Backup will be: {BACKUP_MODEL}")
|
| 392 |
+
print("=" * 88 + "\n")
|
| 393 |
+
|
| 394 |
+
if not LIVE_MODEL.exists():
|
| 395 |
+
raise SystemExit(f"No live model at {LIVE_MODEL}; nothing to back up.")
|
| 396 |
+
|
| 397 |
+
# Step 0: Backup
|
| 398 |
+
logger.info("STEP 0: Backing up live model -> %s", BACKUP_MODEL)
|
| 399 |
+
shutil.copy2(LIVE_MODEL, BACKUP_MODEL)
|
| 400 |
+
|
| 401 |
+
# Step 1+2: Augment data
|
| 402 |
+
if args.skip_data_gen and AUG_DATA.exists():
|
| 403 |
+
logger.info("STEP 1+2: Reusing existing %s", AUG_DATA)
|
| 404 |
+
df_aug = pd.read_csv(AUG_DATA)
|
| 405 |
+
else:
|
| 406 |
+
logger.info("STEP 1+2: Generating augmented dataset")
|
| 407 |
+
df_aug = generate_augmented_dataset()
|
| 408 |
+
|
| 409 |
+
# Step 3: Train candidate
|
| 410 |
+
logger.info("STEP 3: Training candidate GBR")
|
| 411 |
+
train_candidate(df_aug)
|
| 412 |
+
candidate = joblib.load(CANDIDATE_MODEL)
|
| 413 |
+
|
| 414 |
+
# Step 4 + 5: Verify
|
| 415 |
+
preset_wins, preset_rows = verify_presets(candidate)
|
| 416 |
+
random_wins, random_rows = verify_random(candidate, n_seeds=20)
|
| 417 |
+
|
| 418 |
+
# Step 6: Promote / rollback
|
| 419 |
+
print("\n" + "=" * 88)
|
| 420 |
+
print(" GATE DECISION")
|
| 421 |
+
print("-" * 88)
|
| 422 |
+
print(f" Preset wins: {preset_wins}/7 (floor: {args.preset_floor})")
|
| 423 |
+
print(f" Random wins: {random_wins}/20 (floor: {args.random_floor})")
|
| 424 |
+
|
| 425 |
+
promote = (preset_wins >= args.preset_floor) and (random_wins >= args.random_floor)
|
| 426 |
+
|
| 427 |
+
gate_report = {
|
| 428 |
+
"preset_wins": preset_wins,
|
| 429 |
+
"random_wins": random_wins,
|
| 430 |
+
"preset_floor": args.preset_floor,
|
| 431 |
+
"random_floor": args.random_floor,
|
| 432 |
+
"promoted": promote,
|
| 433 |
+
"preset_rows": preset_rows,
|
| 434 |
+
"random_rows": random_rows,
|
| 435 |
+
}
|
| 436 |
+
(RESULTS_DIR / "foolproof_retrain_report.json").write_text(
|
| 437 |
+
json.dumps(gate_report, indent=2)
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
if promote:
|
| 441 |
+
os.replace(str(CANDIDATE_MODEL), str(LIVE_MODEL))
|
| 442 |
+
# Update preset_benchmark.json with new numbers
|
| 443 |
+
out = []
|
| 444 |
+
for r in preset_rows:
|
| 445 |
+
base = r["baseline_tardiness"]
|
| 446 |
+
dahs = r["dahs_tardiness"]
|
| 447 |
+
imp = (base - dahs) / base * 100.0 if base > 0 else 0.0
|
| 448 |
+
out.append({
|
| 449 |
+
"preset": r["preset"],
|
| 450 |
+
"favored": r["favored"],
|
| 451 |
+
"baseline_tardiness": round(base, 2),
|
| 452 |
+
"dahs_tardiness": round(dahs, 2),
|
| 453 |
+
"improvement_pct": round(imp, 2),
|
| 454 |
+
"dahs_wins": r["wins"],
|
| 455 |
+
})
|
| 456 |
+
(RESULTS_DIR / "preset_benchmark.json").write_text(json.dumps(out, indent=2))
|
| 457 |
+
print(" RESULT: PROMOTED. New model is live.")
|
| 458 |
+
print(f" Old model preserved at: {BACKUP_MODEL}")
|
| 459 |
+
else:
|
| 460 |
+
try:
|
| 461 |
+
CANDIDATE_MODEL.unlink()
|
| 462 |
+
except FileNotFoundError:
|
| 463 |
+
pass
|
| 464 |
+
print(" RESULT: REJECTED. Live model unchanged.")
|
| 465 |
+
print(f" Reason:")
|
| 466 |
+
if preset_wins < args.preset_floor:
|
| 467 |
+
print(f" - preset_wins={preset_wins} < floor={args.preset_floor}")
|
| 468 |
+
if random_wins < args.random_floor:
|
| 469 |
+
print(f" - random_wins={random_wins} < floor={args.random_floor}")
|
| 470 |
+
print("=" * 88 + "\n")
|
| 471 |
+
|
| 472 |
+
sys.exit(0 if promote else 1)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
if __name__ == "__main__":
|
| 476 |
+
main()
|
scripts/hf_runner.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
from huggingface_hub import HfApi, login
|
| 4 |
+
|
| 5 |
+
# 1. Configuration
|
| 6 |
+
# We will pass the HF_TOKEN as an environment variable in the HF Job settings
|
| 7 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 8 |
+
REPO_ID = os.environ.get("REPO_ID") # e.g., "your-username/DAHS-Models"
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
if not HF_TOKEN or not REPO_ID:
|
| 12 |
+
print("ERROR: HF_TOKEN and REPO_ID environment variables must be set!")
|
| 13 |
+
return
|
| 14 |
+
|
| 15 |
+
print(f"Logging into Hugging Face...")
|
| 16 |
+
login(token=HF_TOKEN)
|
| 17 |
+
api = HfApi()
|
| 18 |
+
|
| 19 |
+
# Make sure the repository exists
|
| 20 |
+
try:
|
| 21 |
+
api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
|
| 22 |
+
print(f"Repository {REPO_ID} is ready.")
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Failed to create/check repo: {e}")
|
| 25 |
+
|
| 26 |
+
# 2. Run the heavy pipeline
|
| 27 |
+
print("\n--- STARTING DAHS PIPELINE ---")
|
| 28 |
+
# Using subprocess to run the pipeline exactly as you would locally
|
| 29 |
+
result = subprocess.run(["python", "scripts/run_pipeline.py"])
|
| 30 |
+
|
| 31 |
+
if result.returncode != 0:
|
| 32 |
+
print("\nPipeline failed! Aborting upload.")
|
| 33 |
+
return
|
| 34 |
+
print("--- PIPELINE FINISHED SUCCESSFULY ---\n")
|
| 35 |
+
|
| 36 |
+
# 3. Upload the trained models and results back to Hugging Face
|
| 37 |
+
print(f"Uploading models and results to {REPO_ID}...")
|
| 38 |
+
|
| 39 |
+
# Upload models directory
|
| 40 |
+
if os.path.exists("models"):
|
| 41 |
+
api.upload_folder(
|
| 42 |
+
folder_path="models",
|
| 43 |
+
repo_id=REPO_ID,
|
| 44 |
+
repo_type="model",
|
| 45 |
+
path_in_repo="models"
|
| 46 |
+
)
|
| 47 |
+
print("Successfully uploaded models/")
|
| 48 |
+
|
| 49 |
+
# Upload results directory
|
| 50 |
+
if os.path.exists("results"):
|
| 51 |
+
api.upload_folder(
|
| 52 |
+
folder_path="results",
|
| 53 |
+
repo_id=REPO_ID,
|
| 54 |
+
repo_type="model",
|
| 55 |
+
path_in_repo="results"
|
| 56 |
+
)
|
| 57 |
+
print("Successfully uploaded results/")
|
| 58 |
+
|
| 59 |
+
print("\nALL DONE! Your models are safely stored on Hugging Face.")
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
scripts/run_pipeline.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
scripts/run_pipeline.py — DAHS_2 End-to-End Training Pipeline
|
| 4 |
+
|
| 5 |
+
Steps:
|
| 6 |
+
1. Generate selector dataset (snapshot-fork, n_scenarios configurable)
|
| 7 |
+
2. Generate priority dataset
|
| 8 |
+
3. Train selector models (DT, RF, XGB)
|
| 9 |
+
4. Train priority predictor (GBR)
|
| 10 |
+
5. [Optional] Run benchmark evaluation (300 seeds)
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python scripts/run_pipeline.py # Full pipeline (1000 scenarios)
|
| 14 |
+
python scripts/run_pipeline.py --quick # Quick smoke test (50 scenarios, 20 seeds)
|
| 15 |
+
python scripts/run_pipeline.py --eval-only # Run evaluation only (models must exist)
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import logging
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
# Force UTF-8 stdout/stderr on Windows so unicode chars (✓, ×, →) don't
|
| 26 |
+
# crash the pipeline after hours of data generation.
|
| 27 |
+
for _stream in ("stdout", "stderr"):
|
| 28 |
+
try:
|
| 29 |
+
getattr(sys, _stream).reconfigure(encoding="utf-8", errors="replace")
|
| 30 |
+
except Exception:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
# Make sure src is importable from project root
|
| 34 |
+
ROOT = Path(__file__).parent.parent
|
| 35 |
+
sys.path.insert(0, str(ROOT))
|
| 36 |
+
|
| 37 |
+
(ROOT / "logs").mkdir(exist_ok=True)
|
| 38 |
+
(ROOT / "data" / "raw").mkdir(parents=True, exist_ok=True)
|
| 39 |
+
(ROOT / "models").mkdir(exist_ok=True)
|
| 40 |
+
(ROOT / "results" / "plots").mkdir(parents=True, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
_stream_handler = logging.StreamHandler()
|
| 43 |
+
_file_handler = logging.FileHandler(ROOT / "logs" / "pipeline.log", mode="a", encoding="utf-8")
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=logging.INFO,
|
| 46 |
+
format="%(asctime)s %(levelname)s %(message)s",
|
| 47 |
+
handlers=[_stream_handler, _file_handler],
|
| 48 |
+
)
|
| 49 |
+
logger = logging.getLogger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def step(n: int, label: str) -> None:
|
| 53 |
+
print(f"\n{'=' * 60}")
|
| 54 |
+
print(f" STEP {n}: {label}")
|
| 55 |
+
print(f"{'=' * 60}\n")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main() -> None:
|
| 59 |
+
parser = argparse.ArgumentParser(description="DAHS_2 Training Pipeline")
|
| 60 |
+
parser.add_argument("--quick", action="store_true", help="Quick smoke test (50 scenarios, 20 eval seeds)")
|
| 61 |
+
parser.add_argument("--eval-only", action="store_true", help="Skip training, run evaluation only")
|
| 62 |
+
parser.add_argument("--no-eval", action="store_true", help="Skip benchmark evaluation")
|
| 63 |
+
parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers")
|
| 64 |
+
parser.add_argument("--scenarios", type=int, default=None, help="Override number of scenarios")
|
| 65 |
+
args = parser.parse_args()
|
| 66 |
+
|
| 67 |
+
n_scenarios = args.scenarios or (50 if args.quick else 1000)
|
| 68 |
+
n_eval_seeds = 20 if args.quick else 300
|
| 69 |
+
n_workers = args.workers
|
| 70 |
+
|
| 71 |
+
t_start = time.time()
|
| 72 |
+
|
| 73 |
+
print("\n" + "=" * 60)
|
| 74 |
+
print(" DAHS 2.0 — Full Training & Evaluation Pipeline")
|
| 75 |
+
print(f" Scenarios: {n_scenarios} | Workers: {n_workers}")
|
| 76 |
+
print("=" * 60)
|
| 77 |
+
|
| 78 |
+
if not args.eval_only:
|
| 79 |
+
# ── Step 1: Selector dataset ─────────────────────────────────
|
| 80 |
+
step(1, "Snapshot-Fork Selector Dataset")
|
| 81 |
+
from src.data_generator import generate_selector_dataset
|
| 82 |
+
t = time.time()
|
| 83 |
+
df = generate_selector_dataset(n_scenarios=n_scenarios, n_workers=n_workers)
|
| 84 |
+
logger.info("Selector dataset: %d rows in %.1fs", len(df), time.time() - t)
|
| 85 |
+
print(f" ✓ Selector dataset: {len(df):,} rows")
|
| 86 |
+
|
| 87 |
+
# ── Step 2: Priority dataset ─────────────────────────────────
|
| 88 |
+
step(2, "Priority Predictor Dataset")
|
| 89 |
+
from src.data_generator import generate_priority_dataset
|
| 90 |
+
t = time.time()
|
| 91 |
+
priority_df = generate_priority_dataset(
|
| 92 |
+
n_scenarios=min(n_scenarios * 5, 5_000),
|
| 93 |
+
n_points_per=10,
|
| 94 |
+
n_workers=n_workers,
|
| 95 |
+
)
|
| 96 |
+
logger.info("Priority dataset: %d rows in %.1fs", len(priority_df), time.time() - t)
|
| 97 |
+
print(f" ✓ Priority dataset: {len(priority_df):,} rows")
|
| 98 |
+
|
| 99 |
+
# ── Step 3: Train selectors ──────────────────────────────────
|
| 100 |
+
step(3, "Train Selector Models (DT + RF + XGB)")
|
| 101 |
+
from src.train_selector import train_selector_models
|
| 102 |
+
t = time.time()
|
| 103 |
+
selector_models = train_selector_models()
|
| 104 |
+
logger.info("Selector training done in %.1fs", time.time() - t)
|
| 105 |
+
print(f" ✓ Trained: {list(selector_models.keys())}")
|
| 106 |
+
|
| 107 |
+
# ── Step 4: Train priority predictor ────────────────────────
|
| 108 |
+
step(4, "Train Priority Predictor (GBR)")
|
| 109 |
+
from src.train_priority import train_priority_model
|
| 110 |
+
t = time.time()
|
| 111 |
+
gbr = train_priority_model()
|
| 112 |
+
logger.info("Priority training done in %.1fs", time.time() - t)
|
| 113 |
+
print(" ✓ Priority GBR trained")
|
| 114 |
+
|
| 115 |
+
# ── Step 5: Benchmark evaluation ─────────────────────────────────
|
| 116 |
+
if not args.no_eval:
|
| 117 |
+
step(5, "Benchmark Evaluation")
|
| 118 |
+
from src.evaluator import run_full_evaluation
|
| 119 |
+
t = time.time()
|
| 120 |
+
eval_seeds = list(range(99000, 99000 + n_eval_seeds))
|
| 121 |
+
results = run_full_evaluation(seeds=eval_seeds, n_workers=n_workers)
|
| 122 |
+
logger.info("Evaluation done: %d seeds in %.1fs", n_eval_seeds, time.time() - t)
|
| 123 |
+
print(f" ✓ Evaluation complete ({n_eval_seeds} seeds)")
|
| 124 |
+
|
| 125 |
+
# Print summary
|
| 126 |
+
bench_df = results["benchmark"]
|
| 127 |
+
if not bench_df.empty:
|
| 128 |
+
print("\n Performance Summary (mean total tardiness):")
|
| 129 |
+
for method in sorted(bench_df["method"].unique()):
|
| 130 |
+
mean_t = bench_df[bench_df["method"] == method]["total_tardiness"].mean()
|
| 131 |
+
print(f" {method:<20}: {mean_t:>8.1f}")
|
| 132 |
+
|
| 133 |
+
elapsed = time.time() - t_start
|
| 134 |
+
print(f"\n Pipeline complete in {elapsed / 60:.1f} minutes.")
|
| 135 |
+
print(f" Artifacts saved to: {ROOT / 'models'} and {ROOT / 'results'}\n")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
main()
|
scripts/run_preset_benchmark.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
scripts/run_preset_benchmark.py — Per-preset 3-arm benchmark.
|
| 4 |
+
|
| 5 |
+
For each preset in src/presets.py, run THREE simulations on the preset's seed:
|
| 6 |
+
1. Baseline = preset.favored_heuristic (the home-turf specialist)
|
| 7 |
+
2. DAHS-Priority = priority GBR (single fixed model) (one learned ranker)
|
| 8 |
+
3. Meta-selector = BatchwiseSelector + xgb model (the actual product)
|
| 9 |
+
|
| 10 |
+
The 3-arm view honestly addresses No-Free-Lunch:
|
| 11 |
+
- DAHS-Priority is allowed to lose to a hand-tuned specialist on its own preset.
|
| 12 |
+
- The Meta-selector is the actual product — it should match or beat the
|
| 13 |
+
specialist by switching to that heuristic when conditions match.
|
| 14 |
+
|
| 15 |
+
Write results/preset_benchmark.json — consumed by the Simulation page's
|
| 16 |
+
"3-arm preset benchmark" panel.
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python scripts/run_preset_benchmark.py
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, List
|
| 28 |
+
|
| 29 |
+
import joblib
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
ROOT = Path(__file__).parent.parent
|
| 33 |
+
sys.path.insert(0, str(ROOT))
|
| 34 |
+
|
| 35 |
+
from src.simulator import WarehouseSimulator
|
| 36 |
+
from src.features import FeatureExtractor
|
| 37 |
+
from src.heuristics import (
|
| 38 |
+
fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
|
| 39 |
+
atc_dispatch, wspt_dispatch, slack_dispatch,
|
| 40 |
+
)
|
| 41 |
+
from src.presets import PRESETS
|
| 42 |
+
|
| 43 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
DISPATCH_FNS = {
|
| 47 |
+
"fifo": fifo_dispatch,
|
| 48 |
+
"priority_edd": priority_edd_dispatch,
|
| 49 |
+
"critical_ratio": critical_ratio_dispatch,
|
| 50 |
+
"atc": atc_dispatch,
|
| 51 |
+
"wspt": wspt_dispatch,
|
| 52 |
+
"slack": slack_dispatch,
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _make_priority_dispatch(model, fe: FeatureExtractor, sim_ref: list):
|
| 57 |
+
"""Closure: priority-GBR dispatcher that scores jobs per call."""
|
| 58 |
+
def dispatch(jobs, t, zone_id):
|
| 59 |
+
sim = sim_ref[0]
|
| 60 |
+
if not jobs or sim is None:
|
| 61 |
+
return fifo_dispatch(jobs, t, zone_id)
|
| 62 |
+
try:
|
| 63 |
+
state = sim.get_state_snapshot()
|
| 64 |
+
sf = fe.extract_scenario_features(state)
|
| 65 |
+
feats = np.stack([
|
| 66 |
+
np.concatenate([sf, fe.extract_job_features(j, state)])
|
| 67 |
+
for j in jobs
|
| 68 |
+
])
|
| 69 |
+
scores = model.predict(feats)
|
| 70 |
+
return [j for _, j in sorted(zip(scores, jobs),
|
| 71 |
+
key=lambda x: x[0], reverse=True)]
|
| 72 |
+
except Exception as exc:
|
| 73 |
+
logger.warning("priority dispatch fallback (%s)", exc)
|
| 74 |
+
return fifo_dispatch(jobs, t, zone_id)
|
| 75 |
+
return dispatch
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _preset_kwargs(p) -> Dict[str, Any]:
|
| 79 |
+
return dict(
|
| 80 |
+
base_arrival_rate=p.base_arrival_rate,
|
| 81 |
+
breakdown_prob=p.breakdown_prob,
|
| 82 |
+
batch_arrival_size=p.batch_arrival_size,
|
| 83 |
+
lunch_penalty_factor=p.lunch_penalty_factor,
|
| 84 |
+
job_type_frequencies=p.job_type_frequencies,
|
| 85 |
+
due_date_tightness=p.due_date_tightness,
|
| 86 |
+
processing_time_scale=p.processing_time_scale,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _make_meta_dispatch(selector, sim_ref: list):
|
| 91 |
+
"""Closure: BatchwiseSelector dispatcher that re-evaluates state per call."""
|
| 92 |
+
def dispatch(jobs, t, zone_id):
|
| 93 |
+
sim = sim_ref[0]
|
| 94 |
+
if sim is None:
|
| 95 |
+
return fifo_dispatch(jobs, t, zone_id)
|
| 96 |
+
try:
|
| 97 |
+
selector.update_state(sim.get_state_snapshot())
|
| 98 |
+
return selector.dispatch(jobs, t, zone_id)
|
| 99 |
+
except Exception as exc:
|
| 100 |
+
logger.warning("meta dispatch fallback (%s)", exc)
|
| 101 |
+
return fifo_dispatch(jobs, t, zone_id)
|
| 102 |
+
return dispatch
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def run_preset(p, gbr_model, xgb_model) -> Dict[str, Any]:
|
| 106 |
+
"""Run all three arms on one preset and return a row dict."""
|
| 107 |
+
from src.hybrid_scheduler import BatchwiseSelector
|
| 108 |
+
|
| 109 |
+
sim_kw = _preset_kwargs(p)
|
| 110 |
+
|
| 111 |
+
# ── Arm 1: Baseline (favored heuristic) ─────────────────────────────────
|
| 112 |
+
fe1 = FeatureExtractor()
|
| 113 |
+
base_fn = DISPATCH_FNS.get(p.favored_heuristic, fifo_dispatch)
|
| 114 |
+
base_sim = WarehouseSimulator(seed=p.seed, heuristic_fn=base_fn,
|
| 115 |
+
feature_extractor=fe1, **sim_kw)
|
| 116 |
+
base_metrics = base_sim.run(duration=600.0)
|
| 117 |
+
|
| 118 |
+
# ── Arm 2: DAHS-Priority (single fixed GBR) ─────────────────────────────
|
| 119 |
+
fe2 = FeatureExtractor()
|
| 120 |
+
sim_ref2: list = [None]
|
| 121 |
+
dispatch2 = _make_priority_dispatch(gbr_model, fe2, sim_ref2)
|
| 122 |
+
dahs_sim = WarehouseSimulator(seed=p.seed, heuristic_fn=dispatch2,
|
| 123 |
+
feature_extractor=fe2, **sim_kw)
|
| 124 |
+
sim_ref2[0] = dahs_sim
|
| 125 |
+
dahs_metrics = dahs_sim.run(duration=600.0)
|
| 126 |
+
|
| 127 |
+
# ── Arm 3: Meta-selector (BatchwiseSelector with xgb) ───────────────────
|
| 128 |
+
fe3 = FeatureExtractor()
|
| 129 |
+
selector = BatchwiseSelector(model=xgb_model, feature_extractor=fe3)
|
| 130 |
+
sim_ref3: list = [None]
|
| 131 |
+
dispatch3 = _make_meta_dispatch(selector, sim_ref3)
|
| 132 |
+
meta_sim = WarehouseSimulator(seed=p.seed, heuristic_fn=dispatch3,
|
| 133 |
+
feature_extractor=fe3, **sim_kw)
|
| 134 |
+
sim_ref3[0] = meta_sim
|
| 135 |
+
meta_metrics = meta_sim.run(duration=600.0)
|
| 136 |
+
|
| 137 |
+
base_t = float(base_metrics.total_tardiness)
|
| 138 |
+
dahs_t = float(dahs_metrics.total_tardiness)
|
| 139 |
+
meta_t = float(meta_metrics.total_tardiness)
|
| 140 |
+
dahs_imp = (base_t - dahs_t) / base_t * 100.0 if base_t > 0 else 0.0
|
| 141 |
+
meta_imp = (base_t - meta_t) / base_t * 100.0 if base_t > 0 else 0.0
|
| 142 |
+
|
| 143 |
+
# Snapshot which heuristics the meta-selector actually picked
|
| 144 |
+
sw_log = selector.switching_log.entries if selector.switching_log else []
|
| 145 |
+
picks = {}
|
| 146 |
+
for entry in sw_log:
|
| 147 |
+
h = entry.get("selected", "?")
|
| 148 |
+
picks[h] = picks.get(h, 0) + 1
|
| 149 |
+
top_picks = sorted(picks.items(), key=lambda x: x[1], reverse=True)[:3]
|
| 150 |
+
|
| 151 |
+
return {
|
| 152 |
+
"preset": p.name,
|
| 153 |
+
"favored": p.favored_heuristic,
|
| 154 |
+
"seed": int(p.seed),
|
| 155 |
+
"baseline_tardiness": round(base_t, 2),
|
| 156 |
+
"dahs_tardiness": round(dahs_t, 2),
|
| 157 |
+
"meta_tardiness": round(meta_t, 2),
|
| 158 |
+
"baseline_sla_breach": round(float(base_metrics.sla_breach_rate), 4),
|
| 159 |
+
"dahs_sla_breach": round(float(dahs_metrics.sla_breach_rate), 4),
|
| 160 |
+
"meta_sla_breach": round(float(meta_metrics.sla_breach_rate), 4),
|
| 161 |
+
"baseline_completed": int(base_metrics.completed_jobs),
|
| 162 |
+
"dahs_completed": int(dahs_metrics.completed_jobs),
|
| 163 |
+
"meta_completed": int(meta_metrics.completed_jobs),
|
| 164 |
+
"improvement_pct": round(dahs_imp, 2), # back-compat: DAHS-Priority vs baseline
|
| 165 |
+
"meta_improvement_pct": round(meta_imp, 2), # meta-selector vs baseline
|
| 166 |
+
"dahs_wins": dahs_t <= base_t,
|
| 167 |
+
"meta_wins": meta_t <= base_t,
|
| 168 |
+
"meta_top_picks": top_picks, # what did the selector actually pick?
|
| 169 |
+
"meta_n_switches": len(sw_log),
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def main() -> None:
|
| 174 |
+
gbr_path = ROOT / "models" / "priority_gbr.joblib"
|
| 175 |
+
xgb_path = ROOT / "models" / "selector_xgb.joblib"
|
| 176 |
+
if not gbr_path.exists():
|
| 177 |
+
raise SystemExit(f"Missing model: {gbr_path}. Run scripts/run_pipeline.py first.")
|
| 178 |
+
if not xgb_path.exists():
|
| 179 |
+
raise SystemExit(f"Missing model: {xgb_path}. Run scripts/run_pipeline.py first.")
|
| 180 |
+
|
| 181 |
+
logger.info("Loading priority GBR from %s", gbr_path)
|
| 182 |
+
gbr_model = joblib.load(gbr_path)
|
| 183 |
+
logger.info("Loading selector XGB from %s", xgb_path)
|
| 184 |
+
xgb_model = joblib.load(xgb_path)
|
| 185 |
+
|
| 186 |
+
rows: List[Dict[str, Any]] = []
|
| 187 |
+
for p in PRESETS:
|
| 188 |
+
logger.info("Running preset %s (favored=%s, seed=%d)",
|
| 189 |
+
p.name, p.favored_heuristic, p.seed)
|
| 190 |
+
rows.append(run_preset(p, gbr_model, xgb_model))
|
| 191 |
+
|
| 192 |
+
out_path = ROOT / "results" / "preset_benchmark.json"
|
| 193 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 194 |
+
out_path.write_text(json.dumps(rows, indent=2))
|
| 195 |
+
logger.info("Wrote %s", out_path)
|
| 196 |
+
|
| 197 |
+
print("\n" + "=" * 110)
|
| 198 |
+
print(f"{'Preset':<22} {'Favored':<14} {'Baseline':>10} {'DAHS-Pri':>10} {'Meta-sel':>10} "
|
| 199 |
+
f"{'DAHSwin':>8} {'Metawin':>8}")
|
| 200 |
+
print("-" * 110)
|
| 201 |
+
n_dahs = 0
|
| 202 |
+
n_meta = 0
|
| 203 |
+
for r in rows:
|
| 204 |
+
if r["dahs_wins"]: n_dahs += 1
|
| 205 |
+
if r["meta_wins"]: n_meta += 1
|
| 206 |
+
print(f"{r['preset']:<22} {r['favored']:<14} "
|
| 207 |
+
f"{r['baseline_tardiness']:>10.1f} {r['dahs_tardiness']:>10.1f} {r['meta_tardiness']:>10.1f} "
|
| 208 |
+
f"{('YES' if r['dahs_wins'] else 'NO'):>8} {('YES' if r['meta_wins'] else 'NO'):>8}")
|
| 209 |
+
print("=" * 110)
|
| 210 |
+
print(f"DAHS-Priority wins: {n_dahs}/{len(rows)} Meta-selector wins: {n_meta}/{len(rows)}\n")
|
| 211 |
+
|
| 212 |
+
print("Meta-selector heuristic picks per preset:")
|
| 213 |
+
for r in rows:
|
| 214 |
+
picks = r.get("meta_top_picks", [])
|
| 215 |
+
picks_str = ", ".join(f"{h}:{n}" for h, n in picks)
|
| 216 |
+
print(f" {r['preset']:<22} switches={r['meta_n_switches']:<3} top_picks=[{picks_str}]")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
if __name__ == "__main__":
|
| 220 |
+
main()
|
src/__init__.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
__init__.py — Public API for DAHS_2 src package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from src.simulator import (
|
| 6 |
+
WarehouseSimulator,
|
| 7 |
+
SimulationMetrics,
|
| 8 |
+
Job,
|
| 9 |
+
Operation,
|
| 10 |
+
StationState,
|
| 11 |
+
ZoneConfig,
|
| 12 |
+
JobType,
|
| 13 |
+
)
|
| 14 |
+
from src.features import (
|
| 15 |
+
FeatureExtractor,
|
| 16 |
+
SCENARIO_FEATURE_NAMES,
|
| 17 |
+
JOB_FEATURE_NAMES,
|
| 18 |
+
FEATURE_DESCRIPTIONS,
|
| 19 |
+
)
|
| 20 |
+
from src.heuristics import (
|
| 21 |
+
fifo_dispatch,
|
| 22 |
+
priority_edd_dispatch,
|
| 23 |
+
critical_ratio_dispatch,
|
| 24 |
+
atc_dispatch,
|
| 25 |
+
wspt_dispatch,
|
| 26 |
+
slack_dispatch,
|
| 27 |
+
DISPATCH_MAP,
|
| 28 |
+
ALL_HEURISTICS,
|
| 29 |
+
HEURISTIC_LABELS,
|
| 30 |
+
)
|
| 31 |
+
from src.hybrid_scheduler import (
|
| 32 |
+
BatchwiseSelector,
|
| 33 |
+
HybridPriority,
|
| 34 |
+
SwitchingLog,
|
| 35 |
+
load_batchwise_selector,
|
| 36 |
+
load_hybrid_priority,
|
| 37 |
+
)
|
| 38 |
+
from src.presets import (
|
| 39 |
+
PresetScenario,
|
| 40 |
+
PRESETS,
|
| 41 |
+
get_preset,
|
| 42 |
+
get_all_presets,
|
| 43 |
+
run_preset_demo,
|
| 44 |
+
run_all_preset_demos,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
__all__ = [
|
| 48 |
+
# Simulator
|
| 49 |
+
"WarehouseSimulator",
|
| 50 |
+
"SimulationMetrics",
|
| 51 |
+
"Job",
|
| 52 |
+
"Operation",
|
| 53 |
+
"StationState",
|
| 54 |
+
"ZoneConfig",
|
| 55 |
+
"JobType",
|
| 56 |
+
# Features
|
| 57 |
+
"FeatureExtractor",
|
| 58 |
+
"SCENARIO_FEATURE_NAMES",
|
| 59 |
+
"JOB_FEATURE_NAMES",
|
| 60 |
+
"FEATURE_DESCRIPTIONS",
|
| 61 |
+
# Heuristics
|
| 62 |
+
"fifo_dispatch",
|
| 63 |
+
"priority_edd_dispatch",
|
| 64 |
+
"critical_ratio_dispatch",
|
| 65 |
+
"atc_dispatch",
|
| 66 |
+
"wspt_dispatch",
|
| 67 |
+
"slack_dispatch",
|
| 68 |
+
"DISPATCH_MAP",
|
| 69 |
+
"ALL_HEURISTICS",
|
| 70 |
+
"HEURISTIC_LABELS",
|
| 71 |
+
# Hybrid scheduler
|
| 72 |
+
"BatchwiseSelector",
|
| 73 |
+
"HybridPriority",
|
| 74 |
+
"SwitchingLog",
|
| 75 |
+
"load_batchwise_selector",
|
| 76 |
+
"load_hybrid_priority",
|
| 77 |
+
# Presets
|
| 78 |
+
"PresetScenario",
|
| 79 |
+
"PRESETS",
|
| 80 |
+
"get_preset",
|
| 81 |
+
"get_all_presets",
|
| 82 |
+
"run_preset_demo",
|
| 83 |
+
"run_all_preset_demos",
|
| 84 |
+
]
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.4 kB). View file
|
|
|
src/__pycache__/data_generator.cpython-312.pyc
ADDED
|
Binary file (17.2 kB). View file
|
|
|
src/__pycache__/evaluator.cpython-312.pyc
ADDED
|
Binary file (45.4 kB). View file
|
|
|
src/__pycache__/features.cpython-312.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
src/__pycache__/heuristics.cpython-312.pyc
ADDED
|
Binary file (7.89 kB). View file
|
|
|
src/__pycache__/hybrid_scheduler.cpython-312.pyc
ADDED
|
Binary file (38.5 kB). View file
|
|
|
src/__pycache__/presets.cpython-312.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
src/__pycache__/references.cpython-312.pyc
ADDED
|
Binary file (4.94 kB). View file
|
|
|
src/__pycache__/simulator.cpython-312.pyc
ADDED
|
Binary file (65.3 kB). View file
|
|
|
src/__pycache__/train_priority.cpython-312.pyc
ADDED
|
Binary file (7 kB). View file
|
|
|
src/__pycache__/train_selector.cpython-312.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
src/data_generator.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_generator.py — Training Data Generation for DAHS_2
|
| 3 |
+
|
| 4 |
+
NEW in DAHS_2: Snapshot-fork algorithm
|
| 5 |
+
Instead of running full simulations with each heuristic,
|
| 6 |
+
this generator takes snapshots every 10 minutes, forks 6 short
|
| 7 |
+
simulations (20 min each), and labels which heuristic wins per-window.
|
| 8 |
+
Result: ~60 rows per scenario instead of 1, with situation-level labels.
|
| 9 |
+
|
| 10 |
+
Also generates:
|
| 11 |
+
- priority_dataset.csv (same as DAHS_1)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import multiprocessing as mp
|
| 18 |
+
import os
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import pandas as pd
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
DATA_DIR = Path(__file__).parent.parent / "data" / "raw"
|
| 29 |
+
|
| 30 |
+
HEURISTIC_NAMES = [
|
| 31 |
+
"fifo",
|
| 32 |
+
"priority_edd",
|
| 33 |
+
"critical_ratio",
|
| 34 |
+
"atc",
|
| 35 |
+
"wspt",
|
| 36 |
+
"slack",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
SNAPSHOT_INTERVAL = 15.0 # minutes between snapshots (matches BatchwiseSelector.EVAL_INTERVAL)
|
| 40 |
+
FORK_WINDOW = 60.0 # minutes per fork evaluation (covers express SLA window of 60 min)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# 7-region scenario diversity (ported from DAHS_1)
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
def _make_diverse_scenario_configs(n_scenarios: int, rng: np.random.Generator) -> List[Dict[str, Any]]:
|
| 48 |
+
"""Generate diverse simulator parameter configs to avoid class imbalance."""
|
| 49 |
+
configs: List[Dict[str, Any]] = []
|
| 50 |
+
|
| 51 |
+
regions = [
|
| 52 |
+
# FIFO-friendly: low load, uniform jobs, loose deadlines
|
| 53 |
+
{"arrival": (1.0, 2.0), "bkdown": (0.0, 0.001), "due": (1.8, 3.0),
|
| 54 |
+
"batch": (5, 15), "lunch": (1.0, 1.1), "pscale": (0.8, 1.2),
|
| 55 |
+
"mix": "uniform"},
|
| 56 |
+
# Priority-EDD: high express, tight deadlines
|
| 57 |
+
{"arrival": (2.0, 3.5), "bkdown": (0.0, 0.005), "due": (0.4, 0.8),
|
| 58 |
+
"batch": (15, 40), "lunch": (1.0, 1.3), "pscale": (0.8, 1.2),
|
| 59 |
+
"mix": "express_heavy"},
|
| 60 |
+
# Critical-Ratio: high breakdowns, heterogeneous pressure
|
| 61 |
+
{"arrival": (2.0, 3.0), "bkdown": (0.008, 0.020), "due": (0.6, 1.2),
|
| 62 |
+
"batch": (20, 50), "lunch": (1.2, 1.6), "pscale": (1.0, 1.5),
|
| 63 |
+
"mix": "diverse"},
|
| 64 |
+
# ATC: heavy load + surge, weighted tardiness matters
|
| 65 |
+
{"arrival": (3.0, 5.0), "bkdown": (0.001, 0.008), "due": (0.7, 1.1),
|
| 66 |
+
"batch": (30, 80), "lunch": (1.2, 1.5), "pscale": (0.9, 1.3),
|
| 67 |
+
"mix": "diverse"},
|
| 68 |
+
# WSPT: many short jobs, steady flow
|
| 69 |
+
{"arrival": (2.5, 4.0), "bkdown": (0.0, 0.003), "due": (1.0, 1.8),
|
| 70 |
+
"batch": (10, 30), "lunch": (1.0, 1.2), "pscale": (0.5, 0.9),
|
| 71 |
+
"mix": "short_heavy"},
|
| 72 |
+
# Slack: tight deadlines, recovery-mode
|
| 73 |
+
{"arrival": (2.5, 3.5), "bkdown": (0.003, 0.012), "due": (0.2, 0.5),
|
| 74 |
+
"batch": (20, 50), "lunch": (1.3, 1.8), "pscale": (1.0, 1.4),
|
| 75 |
+
"mix": "diverse"},
|
| 76 |
+
# Default / general
|
| 77 |
+
{"arrival": (1.5, 4.0), "bkdown": (0.0, 0.015), "due": (0.5, 2.0),
|
| 78 |
+
"batch": (10, 60), "lunch": (1.0, 1.5), "pscale": (0.7, 1.3),
|
| 79 |
+
"mix": "random"},
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
mix_templates = {
|
| 83 |
+
"uniform": {"A": 0.0, "B": 0.0, "C": 1.0, "D": 0.0, "E": 0.0},
|
| 84 |
+
"express_heavy": {"A": 0.20, "B": 0.10, "C": 0.10, "D": 0.10, "E": 0.50},
|
| 85 |
+
"short_heavy": {"A": 0.35, "B": 0.10, "C": 0.10, "D": 0.05, "E": 0.40},
|
| 86 |
+
"diverse": {"A": 0.25, "B": 0.25, "C": 0.20, "D": 0.15, "E": 0.15},
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
per_region = n_scenarios // len(regions)
|
| 90 |
+
remainder = n_scenarios - per_region * len(regions)
|
| 91 |
+
|
| 92 |
+
seed_counter = 0
|
| 93 |
+
for ri, region in enumerate(regions):
|
| 94 |
+
count = per_region + (1 if ri < remainder else 0)
|
| 95 |
+
for _ in range(count):
|
| 96 |
+
ar = rng.uniform(*region["arrival"])
|
| 97 |
+
bk = rng.uniform(*region["bkdown"])
|
| 98 |
+
dd = rng.uniform(*region["due"])
|
| 99 |
+
bat = int(rng.uniform(*region["batch"]))
|
| 100 |
+
lp = rng.uniform(*region["lunch"])
|
| 101 |
+
ps = rng.uniform(*region["pscale"])
|
| 102 |
+
|
| 103 |
+
if region["mix"] == "random":
|
| 104 |
+
freqs_raw = rng.dirichlet([1, 1, 1, 1, 1])
|
| 105 |
+
jt_freq = {k: float(v) for k, v in zip("ABCDE", freqs_raw)}
|
| 106 |
+
elif region["mix"] in mix_templates:
|
| 107 |
+
base = mix_templates[region["mix"]].copy()
|
| 108 |
+
noise = rng.uniform(-0.05, 0.05, 5)
|
| 109 |
+
vals = np.array([base[k] for k in "ABCDE"]) + noise
|
| 110 |
+
vals = np.clip(vals, 0.01, None)
|
| 111 |
+
vals /= vals.sum()
|
| 112 |
+
jt_freq = {k: float(v) for k, v in zip("ABCDE", vals)}
|
| 113 |
+
else:
|
| 114 |
+
jt_freq = {}
|
| 115 |
+
|
| 116 |
+
configs.append({
|
| 117 |
+
"seed": seed_counter,
|
| 118 |
+
"base_arrival_rate": round(ar, 2),
|
| 119 |
+
"breakdown_prob": round(bk, 4),
|
| 120 |
+
"batch_arrival_size": bat,
|
| 121 |
+
"lunch_penalty_factor": round(lp, 2),
|
| 122 |
+
"job_type_frequencies": jt_freq,
|
| 123 |
+
"due_date_tightness": round(dd, 2),
|
| 124 |
+
"processing_time_scale": round(ps, 2),
|
| 125 |
+
})
|
| 126 |
+
seed_counter += 1
|
| 127 |
+
|
| 128 |
+
return configs
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
# NEW: Snapshot-fork worker (top-level for multiprocessing)
|
| 133 |
+
# ---------------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
def _run_snapshot_scenario(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 136 |
+
"""Worker: run one full scenario with snapshot-fork labeling.
|
| 137 |
+
|
| 138 |
+
Algorithm:
|
| 139 |
+
1. Run base sim (FIFO) to each 10-minute snapshot
|
| 140 |
+
2. At each snapshot, save state and fork 6 heuristics 20 min each
|
| 141 |
+
3. Label the snapshot with the best-performing heuristic
|
| 142 |
+
Returns ~60 rows per scenario.
|
| 143 |
+
"""
|
| 144 |
+
config = args
|
| 145 |
+
from src.heuristics import (
|
| 146 |
+
fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
|
| 147 |
+
atc_dispatch, wspt_dispatch, slack_dispatch, DISPATCH_MAP,
|
| 148 |
+
)
|
| 149 |
+
from src.simulator import WarehouseSimulator
|
| 150 |
+
from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES
|
| 151 |
+
|
| 152 |
+
sim_kw = {
|
| 153 |
+
"base_arrival_rate": config.get("base_arrival_rate", 2.5),
|
| 154 |
+
"breakdown_prob": config.get("breakdown_prob", 0.003),
|
| 155 |
+
"batch_arrival_size": config.get("batch_arrival_size", 30),
|
| 156 |
+
"lunch_penalty_factor": config.get("lunch_penalty_factor", 1.3),
|
| 157 |
+
"job_type_frequencies": config.get("job_type_frequencies", {}),
|
| 158 |
+
"due_date_tightness": config.get("due_date_tightness", 1.0),
|
| 159 |
+
"processing_time_scale": config.get("processing_time_scale", 1.0),
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
seed = config["seed"]
|
| 163 |
+
fe = FeatureExtractor()
|
| 164 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=fifo_dispatch, feature_extractor=fe, **sim_kw)
|
| 165 |
+
sim.init()
|
| 166 |
+
|
| 167 |
+
rows = []
|
| 168 |
+
SIM_DURATION = 600.0
|
| 169 |
+
|
| 170 |
+
for t in np.arange(SNAPSHOT_INTERVAL, SIM_DURATION, SNAPSHOT_INTERVAL):
|
| 171 |
+
t = float(t)
|
| 172 |
+
sim.step_to(t)
|
| 173 |
+
state_snap = sim.get_state_snapshot()
|
| 174 |
+
|
| 175 |
+
# Extract 22 scenario features from current state
|
| 176 |
+
features = fe.extract_scenario_features(state_snap)
|
| 177 |
+
if np.any(~np.isfinite(features)):
|
| 178 |
+
continue # skip bad windows
|
| 179 |
+
|
| 180 |
+
# Save state for forking
|
| 181 |
+
saved_state = sim.save_state()
|
| 182 |
+
|
| 183 |
+
# Fork 6 heuristics for FORK_WINDOW min each, collect raw metrics
|
| 184 |
+
fork_end = t + FORK_WINDOW
|
| 185 |
+
raw_metrics: List[Tuple[float, float, float]] = []
|
| 186 |
+
for heur_name in HEURISTIC_NAMES:
|
| 187 |
+
try:
|
| 188 |
+
heur_fn = DISPATCH_MAP[heur_name]
|
| 189 |
+
fork = WarehouseSimulator.from_state(saved_state, heur_fn)
|
| 190 |
+
fork.step_to(fork_end)
|
| 191 |
+
metrics = fork.get_partial_metrics(since_time=t)
|
| 192 |
+
tard = metrics.total_tardiness if np.isfinite(metrics.total_tardiness) else 1e9
|
| 193 |
+
sla = metrics.sla_breach_rate if np.isfinite(metrics.sla_breach_rate) else 1.0
|
| 194 |
+
cyc = metrics.avg_cycle_time if np.isfinite(metrics.avg_cycle_time) else 1e6
|
| 195 |
+
except Exception:
|
| 196 |
+
tard, sla, cyc = 1e9, 1.0, 1e6
|
| 197 |
+
raw_metrics.append((tard, sla, cyc))
|
| 198 |
+
|
| 199 |
+
# Normalize each metric across the 6 heuristics so units are comparable.
|
| 200 |
+
# Without this, raw tardiness (hundreds-thousands) dominates SLA (0-1) and
|
| 201 |
+
# cycle time (tens), so WSPT gets labeled at almost every snapshot.
|
| 202 |
+
arr = np.array(raw_metrics, dtype=float)
|
| 203 |
+
def _norm(col: np.ndarray) -> np.ndarray:
|
| 204 |
+
lo, hi = float(col.min()), float(col.max())
|
| 205 |
+
if hi - lo < 1e-10:
|
| 206 |
+
return np.zeros_like(col)
|
| 207 |
+
return (col - lo) / (hi - lo)
|
| 208 |
+
n_tard = _norm(arr[:, 0])
|
| 209 |
+
n_sla = _norm(arr[:, 1])
|
| 210 |
+
n_cyc = _norm(arr[:, 2])
|
| 211 |
+
# Weights match the benchmark objective (tardiness-dominant) to avoid
|
| 212 |
+
# cycle-time over-weighting which biased labels toward WSPT.
|
| 213 |
+
scores_arr = 0.55 * n_tard + 0.35 * n_sla + 0.10 * n_cyc
|
| 214 |
+
|
| 215 |
+
# Label: best heuristic for THIS situation (lowest normalized composite).
|
| 216 |
+
# Tie-break: when the top two are within TIE_EPS, break ties by the
|
| 217 |
+
# heuristic that currently has the lower global label frequency.
|
| 218 |
+
# This prevents any rule collapsing the dataset (WSPT dominance).
|
| 219 |
+
TIE_EPS = 0.02
|
| 220 |
+
order = np.argsort(scores_arr)
|
| 221 |
+
best = int(order[0])
|
| 222 |
+
runner = int(order[1]) if len(order) > 1 else best
|
| 223 |
+
if abs(scores_arr[best] - scores_arr[runner]) < TIE_EPS:
|
| 224 |
+
# Use rarity-of-label heuristic: among tied candidates, prefer the one
|
| 225 |
+
# with lower ordinal frequency (approximated by reverse index order —
|
| 226 |
+
# FIFO=0, EDD=1, CR=2, ATC=3, WSPT=4, Slack=5; non-WSPT preferred
|
| 227 |
+
# when roughly equal).
|
| 228 |
+
tied = [int(i) for i in order if scores_arr[i] - scores_arr[best] < TIE_EPS]
|
| 229 |
+
# Prefer the tied heuristic furthest from WSPT (index 4) to diversify
|
| 230 |
+
tied.sort(key=lambda h: abs(h - 4), reverse=True)
|
| 231 |
+
best = tied[0]
|
| 232 |
+
label = best
|
| 233 |
+
scores = scores_arr.tolist()
|
| 234 |
+
|
| 235 |
+
row = {name: float(val) for name, val in zip(SCENARIO_FEATURE_NAMES, features)}
|
| 236 |
+
row["label"] = label
|
| 237 |
+
rows.append(row)
|
| 238 |
+
|
| 239 |
+
return rows
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _composite_score(metrics) -> float:
|
| 243 |
+
"""Scoring formula: 0.40*tardiness + 0.35*sla + 0.25*cycle_time (normalized)."""
|
| 244 |
+
# Raw (unnormalized) — normalization happens across heuristics in the caller
|
| 245 |
+
tard = metrics.total_tardiness if metrics.total_tardiness != float("inf") else 1e9
|
| 246 |
+
sla = metrics.sla_breach_rate if metrics.sla_breach_rate != float("inf") else 1.0
|
| 247 |
+
cyc = metrics.avg_cycle_time if metrics.avg_cycle_time != float("inf") else 1e6
|
| 248 |
+
return 0.40 * tard + 0.35 * sla * 1000 + 0.25 * cyc
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# ---------------------------------------------------------------------------
|
| 252 |
+
# Priority dataset worker (ported from DAHS_1)
|
| 253 |
+
# ---------------------------------------------------------------------------
|
| 254 |
+
|
| 255 |
+
def _run_priority_scenario(args: Tuple[int, int]) -> List[Dict[str, Any]]:
|
| 256 |
+
"""Worker: run one seed with ATC baseline, collect job-level feature rows."""
|
| 257 |
+
seed, n_points = args
|
| 258 |
+
from src.heuristics import atc_dispatch
|
| 259 |
+
from src.simulator import WarehouseSimulator
|
| 260 |
+
from src.features import FeatureExtractor
|
| 261 |
+
|
| 262 |
+
_PRIO_W = {"A": 2.0, "B": 1.5, "C": 1.0, "D": 0.8, "E": 3.0}
|
| 263 |
+
_DD_OFFSET = {"A": 120, "B": 160, "C": 240, "D": 320, "E": 60}
|
| 264 |
+
|
| 265 |
+
fe = FeatureExtractor()
|
| 266 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=atc_dispatch, feature_extractor=fe)
|
| 267 |
+
sim.run(duration=600.0)
|
| 268 |
+
|
| 269 |
+
rows: List[Dict[str, Any]] = []
|
| 270 |
+
state = sim.get_state_snapshot()
|
| 271 |
+
completed = sim.completed_jobs
|
| 272 |
+
|
| 273 |
+
if not completed:
|
| 274 |
+
return rows
|
| 275 |
+
|
| 276 |
+
rng = np.random.default_rng(seed)
|
| 277 |
+
sampled = rng.choice(len(completed), size=min(n_points, len(completed)), replace=False)
|
| 278 |
+
|
| 279 |
+
for idx in sampled:
|
| 280 |
+
job = completed[int(idx)]
|
| 281 |
+
scenario_feats = fe.extract_scenario_features(state)
|
| 282 |
+
job_feats = fe.extract_job_features(job, state)
|
| 283 |
+
|
| 284 |
+
w = _PRIO_W.get(job.job_type, 1.0)
|
| 285 |
+
dd_off = _DD_OFFSET.get(job.job_type, 120)
|
| 286 |
+
cycle_time = job.completion_time - job.arrival_time
|
| 287 |
+
tardiness = max(0.0, job.completion_time - job.due_date)
|
| 288 |
+
|
| 289 |
+
remaining = job.remaining_proc_time()
|
| 290 |
+
time_to_due = job.due_date - state["current_time"]
|
| 291 |
+
urgency = 1.0 - min(1.0, max(0.0, time_to_due / max(dd_off, 1.0)))
|
| 292 |
+
importance = w / 3.0
|
| 293 |
+
efficiency = 1.0 / (1.0 + remaining / 30.0)
|
| 294 |
+
delivery_perf = max(0.0, 1.0 - tardiness / max(dd_off, 1.0))
|
| 295 |
+
|
| 296 |
+
priority_score = float(
|
| 297 |
+
0.30 * urgency
|
| 298 |
+
+ 0.25 * importance
|
| 299 |
+
+ 0.20 * efficiency
|
| 300 |
+
+ 0.25 * delivery_perf
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
if not np.isfinite(priority_score):
|
| 304 |
+
continue
|
| 305 |
+
|
| 306 |
+
row = {
|
| 307 |
+
**{f"sf_{i}": float(v) for i, v in enumerate(scenario_feats)},
|
| 308 |
+
**{f"jf_{i}": float(v) for i, v in enumerate(job_feats)},
|
| 309 |
+
"priority_score": priority_score,
|
| 310 |
+
}
|
| 311 |
+
rows.append(row)
|
| 312 |
+
|
| 313 |
+
return rows
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# ---------------------------------------------------------------------------
|
| 317 |
+
# Dataset generators
|
| 318 |
+
# ---------------------------------------------------------------------------
|
| 319 |
+
|
| 320 |
+
def generate_selector_dataset(
|
| 321 |
+
n_scenarios: int = 1000,
|
| 322 |
+
n_workers: int = 4,
|
| 323 |
+
save: bool = True,
|
| 324 |
+
) -> pd.DataFrame:
|
| 325 |
+
"""Generate the heuristic selector training dataset using snapshot-fork algorithm.
|
| 326 |
+
|
| 327 |
+
Parameters
|
| 328 |
+
----------
|
| 329 |
+
n_scenarios : int
|
| 330 |
+
Number of scenario seeds to simulate.
|
| 331 |
+
n_workers : int
|
| 332 |
+
Number of parallel worker processes.
|
| 333 |
+
save : bool
|
| 334 |
+
Whether to save the CSV to data/raw/.
|
| 335 |
+
|
| 336 |
+
Returns
|
| 337 |
+
-------
|
| 338 |
+
pd.DataFrame
|
| 339 |
+
22 scenario feature columns + "label" (0-5, one per heuristic).
|
| 340 |
+
~60 rows per scenario (one per 10-min snapshot).
|
| 341 |
+
"""
|
| 342 |
+
from src.features import SCENARIO_FEATURE_NAMES
|
| 343 |
+
|
| 344 |
+
master_rng = np.random.default_rng(777)
|
| 345 |
+
configs = _make_diverse_scenario_configs(n_scenarios, master_rng)
|
| 346 |
+
|
| 347 |
+
logger.info(
|
| 348 |
+
"Generating selector dataset (snapshot-fork): %d scenarios × ~60 snapshots each",
|
| 349 |
+
n_scenarios
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
all_rows: List[Dict[str, Any]] = []
|
| 353 |
+
ctx = mp.get_context("spawn")
|
| 354 |
+
with ctx.Pool(processes=n_workers) as pool:
|
| 355 |
+
for rows in tqdm(
|
| 356 |
+
pool.imap_unordered(_run_snapshot_scenario, configs),
|
| 357 |
+
total=len(configs),
|
| 358 |
+
desc="Snapshot-fork data gen",
|
| 359 |
+
):
|
| 360 |
+
all_rows.extend(rows)
|
| 361 |
+
|
| 362 |
+
df = pd.DataFrame(all_rows)
|
| 363 |
+
|
| 364 |
+
# Sanitize
|
| 365 |
+
df = df.replace([np.inf, -np.inf], np.nan).fillna(0.0)
|
| 366 |
+
|
| 367 |
+
logger.info("Selector dataset shape: %s", df.shape)
|
| 368 |
+
if "label" in df.columns:
|
| 369 |
+
label_counts = df["label"].value_counts().to_dict()
|
| 370 |
+
logger.info("Label distribution: %s", label_counts)
|
| 371 |
+
|
| 372 |
+
if save:
|
| 373 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 374 |
+
path = DATA_DIR / "selector_dataset.csv"
|
| 375 |
+
df.to_csv(path, index=False)
|
| 376 |
+
logger.info("Saved selector dataset -> %s", path)
|
| 377 |
+
|
| 378 |
+
return df
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def generate_priority_dataset(
|
| 382 |
+
n_scenarios: int = 5_000,
|
| 383 |
+
n_points_per: int = 10,
|
| 384 |
+
n_workers: int = 4,
|
| 385 |
+
save: bool = True,
|
| 386 |
+
) -> pd.DataFrame:
|
| 387 |
+
"""Generate the priority predictor training dataset (ported from DAHS_1)."""
|
| 388 |
+
from src.features import SCENARIO_FEATURE_NAMES, JOB_FEATURE_NAMES
|
| 389 |
+
|
| 390 |
+
seeds = list(range(20_000, 20_000 + n_scenarios))
|
| 391 |
+
all_args = [(seed, n_points_per) for seed in seeds]
|
| 392 |
+
|
| 393 |
+
logger.info("Generating priority dataset: %d scenarios × %d points", n_scenarios, n_points_per)
|
| 394 |
+
|
| 395 |
+
all_rows: List[Dict] = []
|
| 396 |
+
ctx = mp.get_context("spawn")
|
| 397 |
+
with ctx.Pool(processes=n_workers) as pool:
|
| 398 |
+
for batch in tqdm(
|
| 399 |
+
pool.imap_unordered(_run_priority_scenario, all_args),
|
| 400 |
+
total=len(all_args),
|
| 401 |
+
desc="Priority data gen",
|
| 402 |
+
):
|
| 403 |
+
all_rows.extend(batch)
|
| 404 |
+
|
| 405 |
+
df = pd.DataFrame(all_rows)
|
| 406 |
+
df = df.replace([np.inf, -np.inf], np.nan).dropna()
|
| 407 |
+
|
| 408 |
+
sf_names = {f"sf_{i}": name for i, name in enumerate(SCENARIO_FEATURE_NAMES)}
|
| 409 |
+
jf_names = {f"jf_{i}": name for i, name in enumerate(JOB_FEATURE_NAMES)}
|
| 410 |
+
df.rename(columns={**sf_names, **jf_names}, inplace=True)
|
| 411 |
+
|
| 412 |
+
logger.info("Priority dataset shape: %s", df.shape)
|
| 413 |
+
|
| 414 |
+
if save:
|
| 415 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 416 |
+
path = DATA_DIR / "priority_dataset.csv"
|
| 417 |
+
df.to_csv(path, index=False)
|
| 418 |
+
logger.info("Saved priority dataset -> %s", path)
|
| 419 |
+
|
| 420 |
+
return df
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
if __name__ == "__main__":
|
| 424 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 425 |
+
generate_selector_dataset(n_scenarios=50, n_workers=2)
|
src/evaluator.py
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
evaluator.py — Benchmark & Statistical Analysis Pipeline (DAHS_2)
|
| 3 |
+
|
| 4 |
+
Port from DAHS_1 evaluator.py + extensions:
|
| 5 |
+
- 300 test seeds (99000-99299) × 9 methods
|
| 6 |
+
- Statistical tests: Friedman, Nemenyi, Wilcoxon, Cohen's d, Bootstrap CI
|
| 7 |
+
- NEW: Switching analysis (evaluations, switches, hysteresis rate, distribution)
|
| 8 |
+
- NEW: JSON export for frontend Results page
|
| 9 |
+
- 11 dark-theme plots
|
| 10 |
+
|
| 11 |
+
Statistical Methodology References
|
| 12 |
+
-----------------------------------
|
| 13 |
+
- Friedman non-parametric test for k ≥ 3 related samples:
|
| 14 |
+
Friedman, M. (1940). A comparison of alternative tests of significance
|
| 15 |
+
for the problem of m rankings. Annals of Mathematical Statistics, 11(1), 86-92.
|
| 16 |
+
Recommended protocol for ML comparison:
|
| 17 |
+
Demsar, J. (2006). Statistical comparisons of classifiers over multiple
|
| 18 |
+
data sets. Journal of Machine Learning Research, 7, 1-30.
|
| 19 |
+
|
| 20 |
+
- Nemenyi post-hoc pairwise test (Critical Difference diagram):
|
| 21 |
+
Nemenyi, P. (1963). Distribution-free multiple comparisons.
|
| 22 |
+
PhD thesis, Princeton University.
|
| 23 |
+
Applied per: Demsar (2006), JMLR 7:1-30.
|
| 24 |
+
|
| 25 |
+
- Wilcoxon signed-rank test (pairwise DAHS vs each baseline):
|
| 26 |
+
Wilcoxon, F. (1945). Individual comparisons by ranking methods.
|
| 27 |
+
Biometrics Bulletin, 1(6), 80-83. doi:10.2307/3001968.
|
| 28 |
+
|
| 29 |
+
- Cohen's d effect size:
|
| 30 |
+
Cohen, J. (1988). Statistical Power Analysis for the Behavioral
|
| 31 |
+
Sciences. Lawrence Erlbaum Associates (2nd ed.).
|
| 32 |
+
d > 0.2 small, d > 0.5 medium, d > 0.8 large.
|
| 33 |
+
|
| 34 |
+
- Holm-Bonferroni multiple comparison correction:
|
| 35 |
+
Holm, S. (1979). A simple sequentially rejective multiple test
|
| 36 |
+
procedure. Scandinavian Journal of Statistics, 6(2), 65-70.
|
| 37 |
+
|
| 38 |
+
- Bootstrap 95% CI (5,000 resamples):
|
| 39 |
+
Efron, B. & Tibshirani, R.J. (1993). An Introduction to the
|
| 40 |
+
Bootstrap. Chapman & Hall.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
from __future__ import annotations
|
| 44 |
+
|
| 45 |
+
import json
|
| 46 |
+
import logging
|
| 47 |
+
import math
|
| 48 |
+
import warnings
|
| 49 |
+
from pathlib import Path
|
| 50 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 51 |
+
|
| 52 |
+
import matplotlib
|
| 53 |
+
matplotlib.use("Agg")
|
| 54 |
+
import matplotlib.pyplot as plt
|
| 55 |
+
import numpy as np
|
| 56 |
+
import pandas as pd
|
| 57 |
+
from scipy import stats
|
| 58 |
+
|
| 59 |
+
warnings.filterwarnings("ignore")
|
| 60 |
+
|
| 61 |
+
logger = logging.getLogger(__name__)
|
| 62 |
+
|
| 63 |
+
RESULTS_DIR = Path(__file__).parent.parent / "results"
|
| 64 |
+
PLOTS_DIR = RESULTS_DIR / "plots"
|
| 65 |
+
MODELS_DIR = Path(__file__).parent.parent / "models"
|
| 66 |
+
|
| 67 |
+
HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]
|
| 68 |
+
HEURISTIC_LABELS = ["FIFO", "Priority-EDD", "Critical-Ratio", "ATC", "WSPT", "Slack"]
|
| 69 |
+
|
| 70 |
+
DARK_BG = "#0f1117"
|
| 71 |
+
DARK_AX = "#1a1d27"
|
| 72 |
+
TEXT_COL = "#e0e0e0"
|
| 73 |
+
|
| 74 |
+
COLORS = ["#4fc3f7", "#81c784", "#ffb74d", "#e57373", "#ce93d8", "#80cbc4",
|
| 75 |
+
"#fff176", "#ff8a65", "#90caf9", "#f48fb1"]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _dark_fig(figsize=(12, 7)):
|
| 79 |
+
fig, ax = plt.subplots(figsize=figsize)
|
| 80 |
+
fig.patch.set_facecolor(DARK_BG)
|
| 81 |
+
ax.set_facecolor(DARK_AX)
|
| 82 |
+
ax.tick_params(colors=TEXT_COL)
|
| 83 |
+
ax.xaxis.label.set_color(TEXT_COL)
|
| 84 |
+
ax.yaxis.label.set_color(TEXT_COL)
|
| 85 |
+
ax.title.set_color(TEXT_COL)
|
| 86 |
+
for spine in ax.spines.values():
|
| 87 |
+
spine.set_color("#333344")
|
| 88 |
+
return fig, ax
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _dark_fig_multi(rows=1, cols=2, figsize=(16, 7)):
|
| 92 |
+
fig, axes = plt.subplots(rows, cols, figsize=figsize)
|
| 93 |
+
fig.patch.set_facecolor(DARK_BG)
|
| 94 |
+
for ax in np.array(axes).flatten():
|
| 95 |
+
ax.set_facecolor(DARK_AX)
|
| 96 |
+
ax.tick_params(colors=TEXT_COL)
|
| 97 |
+
ax.xaxis.label.set_color(TEXT_COL)
|
| 98 |
+
ax.yaxis.label.set_color(TEXT_COL)
|
| 99 |
+
ax.title.set_color(TEXT_COL)
|
| 100 |
+
for spine in ax.spines.values():
|
| 101 |
+
spine.set_color("#333344")
|
| 102 |
+
return fig, axes
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _norm_min_max(arr: np.ndarray) -> np.ndarray:
|
| 106 |
+
r = arr.max() - arr.min()
|
| 107 |
+
if r < 1e-10:
|
| 108 |
+
return np.zeros_like(arr)
|
| 109 |
+
return (arr - arr.min()) / r
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
# Benchmark runner
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
|
| 116 |
+
def run_benchmark(
|
| 117 |
+
seeds: Optional[List[int]] = None,
|
| 118 |
+
n_workers: int = 4,
|
| 119 |
+
save_csv: bool = True,
|
| 120 |
+
) -> pd.DataFrame:
|
| 121 |
+
"""Run benchmark across all seeds × 9 methods.
|
| 122 |
+
|
| 123 |
+
Methods:
|
| 124 |
+
0-5: 6 baselines (FIFO, Priority-EDD, CR, ATC, WSPT, Slack)
|
| 125 |
+
6: Hybrid-Priority (GBR)
|
| 126 |
+
7: DAHS-RF (Random Forest selector)
|
| 127 |
+
8: DAHS-XGB (XGBoost selector)
|
| 128 |
+
"""
|
| 129 |
+
import multiprocessing as mp
|
| 130 |
+
from tqdm import tqdm
|
| 131 |
+
|
| 132 |
+
if seeds is None:
|
| 133 |
+
seeds = list(range(99000, 99300)) # 300 test seeds
|
| 134 |
+
|
| 135 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 136 |
+
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 137 |
+
|
| 138 |
+
logger.info("Running benchmark: %d seeds × 9 methods", len(seeds))
|
| 139 |
+
|
| 140 |
+
all_args = [(seed,) for seed in seeds]
|
| 141 |
+
|
| 142 |
+
rows = []
|
| 143 |
+
ctx = mp.get_context("spawn")
|
| 144 |
+
with ctx.Pool(processes=n_workers) as pool:
|
| 145 |
+
for result in tqdm(
|
| 146 |
+
pool.imap_unordered(_benchmark_single_seed, all_args),
|
| 147 |
+
total=len(all_args),
|
| 148 |
+
desc="Benchmark",
|
| 149 |
+
):
|
| 150 |
+
rows.extend(result)
|
| 151 |
+
|
| 152 |
+
df = pd.DataFrame(rows)
|
| 153 |
+
logger.info("Benchmark complete: %s rows", len(df))
|
| 154 |
+
|
| 155 |
+
if save_csv:
|
| 156 |
+
path = RESULTS_DIR / "benchmark_results.csv"
|
| 157 |
+
df.to_csv(path, index=False)
|
| 158 |
+
logger.info("Saved -> %s", path)
|
| 159 |
+
|
| 160 |
+
return df
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _row(seed: int, method: str, m: Any, elapsed: float) -> Dict[str, Any]:
|
| 164 |
+
"""Build one benchmark row from a SimulationMetrics + wall-clock seconds.
|
| 165 |
+
|
| 166 |
+
Wall-clock matters for paper review: a method that wins on tardiness but
|
| 167 |
+
is 50× slower than ATC isn't deployable. We capture it on every row so
|
| 168 |
+
"DAHS adds X ms per dispatch" claims are backed by data, not asserted.
|
| 169 |
+
"""
|
| 170 |
+
util_vals = list(m.zone_utilization.values())
|
| 171 |
+
return {
|
| 172 |
+
"seed": seed,
|
| 173 |
+
"method": method,
|
| 174 |
+
"makespan": m.makespan,
|
| 175 |
+
"total_tardiness": m.total_tardiness,
|
| 176 |
+
"sla_breach_rate": m.sla_breach_rate,
|
| 177 |
+
"avg_cycle_time": m.avg_cycle_time,
|
| 178 |
+
"zone_utilization_avg": float(np.mean(util_vals)) if util_vals else 0.0,
|
| 179 |
+
"throughput": m.throughput,
|
| 180 |
+
"queue_max": m.queue_max,
|
| 181 |
+
"completed_jobs": m.completed_jobs,
|
| 182 |
+
"elapsed_seconds": round(float(elapsed), 4),
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _benchmark_single_seed(args: Tuple) -> List[Dict[str, Any]]:
|
| 187 |
+
"""Worker: run all methods on one seed and return their metric rows."""
|
| 188 |
+
(seed,) = args
|
| 189 |
+
import time as _time
|
| 190 |
+
from src.heuristics import (
|
| 191 |
+
fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
|
| 192 |
+
atc_dispatch, wspt_dispatch, slack_dispatch,
|
| 193 |
+
)
|
| 194 |
+
from src.simulator import WarehouseSimulator
|
| 195 |
+
from src.features import FeatureExtractor
|
| 196 |
+
|
| 197 |
+
rows: List[Dict[str, Any]] = []
|
| 198 |
+
methods = [
|
| 199 |
+
("fifo", fifo_dispatch),
|
| 200 |
+
("priority_edd", priority_edd_dispatch),
|
| 201 |
+
("critical_ratio", critical_ratio_dispatch),
|
| 202 |
+
("atc", atc_dispatch),
|
| 203 |
+
("wspt", wspt_dispatch),
|
| 204 |
+
("slack", slack_dispatch),
|
| 205 |
+
]
|
| 206 |
+
|
| 207 |
+
# Capture per-baseline tardiness/SLA/cycle/throughput on this seed so we
|
| 208 |
+
# can synthesise a "best fixed heuristic in hindsight" row at the end.
|
| 209 |
+
# An operator picking the post-hoc best fixed rule is the natural lower
|
| 210 |
+
# bound any learned scheduler must beat.
|
| 211 |
+
baseline_metrics: Dict[str, Any] = {}
|
| 212 |
+
|
| 213 |
+
for method_name, heur_fn in methods:
|
| 214 |
+
try:
|
| 215 |
+
fe = FeatureExtractor()
|
| 216 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=heur_fn, feature_extractor=fe)
|
| 217 |
+
t0 = _time.perf_counter()
|
| 218 |
+
m = sim.run(duration=600.0)
|
| 219 |
+
elapsed = _time.perf_counter() - t0
|
| 220 |
+
rows.append(_row(seed, method_name, m, elapsed))
|
| 221 |
+
baseline_metrics[method_name] = m
|
| 222 |
+
except Exception as e:
|
| 223 |
+
logger.warning("[%s] %s failed: %s", seed, method_name, e)
|
| 224 |
+
|
| 225 |
+
# Best-fixed-in-hindsight oracle: minimum tardiness across the six fixed
|
| 226 |
+
# rules. For non-tardiness metrics we copy the corresponding metric from
|
| 227 |
+
# the same winning method so SLA/cycle/throughput stay self-consistent.
|
| 228 |
+
if baseline_metrics:
|
| 229 |
+
winner_name = min(
|
| 230 |
+
baseline_metrics,
|
| 231 |
+
key=lambda k: baseline_metrics[k].total_tardiness,
|
| 232 |
+
)
|
| 233 |
+
wm = baseline_metrics[winner_name]
|
| 234 |
+
rows.append({
|
| 235 |
+
**_row(seed, "best_fixed_oracle", wm, 0.0),
|
| 236 |
+
"best_fixed_winner": winner_name,
|
| 237 |
+
})
|
| 238 |
+
|
| 239 |
+
# Try hybrid methods if models exist.
|
| 240 |
+
# For each trained model we run TWO variants:
|
| 241 |
+
# dahs_{name} — greedy ML only (BatchwiseSelector), ablation baseline
|
| 242 |
+
# dahs_hybrid_{name} — ML + rolling-horizon fork oracle (guarantees ≥ best fixed)
|
| 243 |
+
for model_name in ("rf", "xgb"):
|
| 244 |
+
model_path = MODELS_DIR / f"selector_{model_name}.joblib"
|
| 245 |
+
if not model_path.exists():
|
| 246 |
+
continue
|
| 247 |
+
try:
|
| 248 |
+
import joblib
|
| 249 |
+
from src.hybrid_scheduler import BatchwiseSelector, RollingHorizonOracle
|
| 250 |
+
|
| 251 |
+
model = joblib.load(model_path)
|
| 252 |
+
|
| 253 |
+
# ── (a) ML-only (greedy) — shows ML alone is insufficient ─────
|
| 254 |
+
fe = FeatureExtractor()
|
| 255 |
+
selector = BatchwiseSelector(model=model, feature_extractor=fe)
|
| 256 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=fifo_dispatch, feature_extractor=fe)
|
| 257 |
+
|
| 258 |
+
def make_dispatch(sel, s):
|
| 259 |
+
def _dispatch(jobs, t, zone_id):
|
| 260 |
+
sel.update_state(s.get_state_snapshot())
|
| 261 |
+
return sel.dispatch(jobs, t, zone_id)
|
| 262 |
+
return _dispatch
|
| 263 |
+
|
| 264 |
+
sim.heuristic_fn = make_dispatch(selector, sim)
|
| 265 |
+
t0 = _time.perf_counter()
|
| 266 |
+
m = sim.run(duration=600.0)
|
| 267 |
+
rows.append(_row(seed, f"dahs_{model_name}", m, _time.perf_counter() - t0))
|
| 268 |
+
|
| 269 |
+
# ── (b) Hybrid = ML prior + fork oracle (the guarantee) ────────
|
| 270 |
+
fe2 = FeatureExtractor()
|
| 271 |
+
oracle = RollingHorizonOracle(ml_model=model, feature_extractor=fe2)
|
| 272 |
+
sim2 = WarehouseSimulator(seed=seed, heuristic_fn=fifo_dispatch, feature_extractor=fe2)
|
| 273 |
+
oracle.attach_simulator(sim2)
|
| 274 |
+
sim2.heuristic_fn = lambda jobs, t, z: oracle.dispatch(jobs, t, z)
|
| 275 |
+
t0 = _time.perf_counter()
|
| 276 |
+
m2 = sim2.run(duration=600.0)
|
| 277 |
+
rows.append(_row(seed, f"dahs_hybrid_{model_name}", m2, _time.perf_counter() - t0))
|
| 278 |
+
except Exception as e:
|
| 279 |
+
logger.warning("[%s] dahs_%s failed: %s", seed, model_name, e)
|
| 280 |
+
|
| 281 |
+
# ── DAHS-Oracle: pure fork oracle, no ML (theoretical ceiling) ──────
|
| 282 |
+
try:
|
| 283 |
+
from src.hybrid_scheduler import RollingHorizonOracle
|
| 284 |
+
|
| 285 |
+
feo = FeatureExtractor()
|
| 286 |
+
oracle = RollingHorizonOracle(ml_model=None, feature_extractor=None)
|
| 287 |
+
simo = WarehouseSimulator(seed=seed, heuristic_fn=fifo_dispatch, feature_extractor=feo)
|
| 288 |
+
oracle.attach_simulator(simo)
|
| 289 |
+
simo.heuristic_fn = lambda jobs, t, z: oracle.dispatch(jobs, t, z)
|
| 290 |
+
t0 = _time.perf_counter()
|
| 291 |
+
mo = simo.run(duration=600.0)
|
| 292 |
+
rows.append(_row(seed, "dahs_oracle", mo, _time.perf_counter() - t0))
|
| 293 |
+
except Exception as e:
|
| 294 |
+
logger.warning("[%s] dahs_oracle failed: %s", seed, e)
|
| 295 |
+
|
| 296 |
+
# Priority hybrid (per-job GBR scorer). NOTE: held last in the headline
|
| 297 |
+
# priority list because its training CV R² was 0.022 ± 0.717 — keep it
|
| 298 |
+
# in the benchmark for completeness/ablation but do not let it lead.
|
| 299 |
+
priority_path = MODELS_DIR / "priority_gbr.joblib"
|
| 300 |
+
if priority_path.exists():
|
| 301 |
+
try:
|
| 302 |
+
import joblib
|
| 303 |
+
from src.hybrid_scheduler import HybridPriority
|
| 304 |
+
|
| 305 |
+
fe = FeatureExtractor()
|
| 306 |
+
priority = HybridPriority(model_path=priority_path, feature_extractor=fe)
|
| 307 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=fifo_dispatch, feature_extractor=fe)
|
| 308 |
+
|
| 309 |
+
def _priority_dispatch(jobs, t, zone_id):
|
| 310 |
+
priority.update_state(sim.get_state_snapshot())
|
| 311 |
+
return priority(jobs, t, zone_id)
|
| 312 |
+
|
| 313 |
+
sim.heuristic_fn = _priority_dispatch
|
| 314 |
+
t0 = _time.perf_counter()
|
| 315 |
+
m = sim.run(duration=600.0)
|
| 316 |
+
rows.append(_row(seed, "hybrid_priority", m, _time.perf_counter() - t0))
|
| 317 |
+
except Exception as e:
|
| 318 |
+
logger.warning("[%s] hybrid_priority failed: %s", seed, e)
|
| 319 |
+
|
| 320 |
+
return rows
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# ---------------------------------------------------------------------------
|
| 324 |
+
# Statistical analysis
|
| 325 |
+
# ---------------------------------------------------------------------------
|
| 326 |
+
|
| 327 |
+
# Direction of preference per metric. "lower" means smaller value is better
|
| 328 |
+
# (e.g. tardiness, SLA breach, cycle time); "higher" means larger is better
|
| 329 |
+
# (throughput, utilization). Used to set the alternative for the one-sided
|
| 330 |
+
# Wilcoxon and to sign Cohen's d so a positive value always means "DAHS wins."
|
| 331 |
+
METRIC_DIRECTIONS: Dict[str, str] = {
|
| 332 |
+
"total_tardiness": "lower",
|
| 333 |
+
"sla_breach_rate": "lower",
|
| 334 |
+
"avg_cycle_time": "lower",
|
| 335 |
+
"makespan": "lower",
|
| 336 |
+
"throughput": "higher",
|
| 337 |
+
"zone_utilization_avg": "higher",
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _wilcoxon_for_metric(
|
| 342 |
+
pivot: pd.DataFrame,
|
| 343 |
+
available_methods: List[str],
|
| 344 |
+
dahs_col: str,
|
| 345 |
+
metric: str,
|
| 346 |
+
direction: str,
|
| 347 |
+
) -> List[Dict[str, Any]]:
|
| 348 |
+
"""One-sided Wilcoxon DAHS-vs-baseline for a single metric.
|
| 349 |
+
|
| 350 |
+
Lower-is-better metrics test H1: baseline > DAHS, so a small p-value means
|
| 351 |
+
DAHS is significantly *lower* (better). Higher-is-better metrics test
|
| 352 |
+
H1: DAHS > baseline. `diff` is always (better-side - worse-side) so the
|
| 353 |
+
resulting Cohen's d is positive when DAHS wins, negative when it loses.
|
| 354 |
+
Holm-Bonferroni is applied within each metric family by the caller.
|
| 355 |
+
"""
|
| 356 |
+
rows: List[Dict[str, Any]] = []
|
| 357 |
+
if dahs_col not in pivot.columns:
|
| 358 |
+
return rows
|
| 359 |
+
dahs_vals = pivot[dahs_col].values
|
| 360 |
+
for method in available_methods:
|
| 361 |
+
if method == dahs_col:
|
| 362 |
+
continue
|
| 363 |
+
try:
|
| 364 |
+
base_vals = pivot[method].values
|
| 365 |
+
if direction == "lower":
|
| 366 |
+
stat, p = stats.wilcoxon(base_vals, dahs_vals, alternative="greater")
|
| 367 |
+
diff = base_vals - dahs_vals
|
| 368 |
+
else:
|
| 369 |
+
stat, p = stats.wilcoxon(dahs_vals, base_vals, alternative="greater")
|
| 370 |
+
diff = dahs_vals - base_vals
|
| 371 |
+
d = float(np.mean(diff) / (np.std(diff) + 1e-10))
|
| 372 |
+
boot_means = [
|
| 373 |
+
np.mean(np.random.choice(diff, size=len(diff), replace=True))
|
| 374 |
+
for _ in range(5000)
|
| 375 |
+
]
|
| 376 |
+
ci_lo, ci_hi = np.percentile(boot_means, [2.5, 97.5])
|
| 377 |
+
rows.append({
|
| 378 |
+
"metric": metric,
|
| 379 |
+
"direction": direction,
|
| 380 |
+
"baseline": method,
|
| 381 |
+
"dahs": dahs_col,
|
| 382 |
+
"statistic": round(float(stat), 4),
|
| 383 |
+
"p_value": float(p),
|
| 384 |
+
"significant_holm": False,
|
| 385 |
+
"cohens_d": round(d, 4),
|
| 386 |
+
"ci_95_lo": round(float(ci_lo), 4),
|
| 387 |
+
"ci_95_hi": round(float(ci_hi), 4),
|
| 388 |
+
})
|
| 389 |
+
except Exception as exc:
|
| 390 |
+
logger.warning("Wilcoxon failed for %s on %s: %s", method, metric, exc)
|
| 391 |
+
if rows:
|
| 392 |
+
ps = [r["p_value"] for r in rows]
|
| 393 |
+
n = len(ps)
|
| 394 |
+
order = np.argsort(ps)
|
| 395 |
+
for rank, idx in enumerate(order):
|
| 396 |
+
rows[idx]["significant_holm"] = ps[idx] < (0.05 / (n - rank))
|
| 397 |
+
return rows
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def _nemenyi_critical_difference(k: int, n: int, alpha: float = 0.05) -> float:
|
| 401 |
+
"""Nemenyi critical-difference for k methods over n datasets at alpha=0.05.
|
| 402 |
+
|
| 403 |
+
CD = q_alpha * sqrt(k*(k+1) / (6*n)) per Demsar (2006), JMLR 7:1-30.
|
| 404 |
+
"""
|
| 405 |
+
Q_05 = {
|
| 406 |
+
2: 1.960, 3: 2.343, 4: 2.569, 5: 2.728, 6: 2.850, 7: 2.949,
|
| 407 |
+
8: 3.031, 9: 3.102, 10: 3.164,
|
| 408 |
+
}
|
| 409 |
+
q = Q_05.get(k, Q_05[10] + 0.05 * (k - 10))
|
| 410 |
+
return float(q * math.sqrt(k * (k + 1) / (6.0 * n)))
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _nemenyi_pairwise(pivot: pd.DataFrame, available_methods: List[str]) -> Dict[str, Any]:
|
| 414 |
+
"""Nemenyi pairwise comparisons + critical difference for the primary metric."""
|
| 415 |
+
if len(available_methods) < 3 or pivot.shape[0] < 2:
|
| 416 |
+
return {"available": False, "reason": "need >=3 methods and >=2 seeds"}
|
| 417 |
+
|
| 418 |
+
ranks = pivot[available_methods].rank(axis=1, method="average")
|
| 419 |
+
mean_ranks = ranks.mean(axis=0).to_dict()
|
| 420 |
+
n_seeds = ranks.shape[0]
|
| 421 |
+
k = len(available_methods)
|
| 422 |
+
cd = _nemenyi_critical_difference(k, n_seeds)
|
| 423 |
+
|
| 424 |
+
matrix: List[Dict[str, Any]] = []
|
| 425 |
+
for i, mi in enumerate(available_methods):
|
| 426 |
+
for j, mj in enumerate(available_methods):
|
| 427 |
+
if j <= i:
|
| 428 |
+
continue
|
| 429 |
+
diff = abs(mean_ranks[mi] - mean_ranks[mj])
|
| 430 |
+
matrix.append({
|
| 431 |
+
"method_a": mi,
|
| 432 |
+
"method_b": mj,
|
| 433 |
+
"rank_a": round(float(mean_ranks[mi]), 4),
|
| 434 |
+
"rank_b": round(float(mean_ranks[mj]), 4),
|
| 435 |
+
"rank_diff": round(float(diff), 4),
|
| 436 |
+
"significant": bool(diff > cd),
|
| 437 |
+
})
|
| 438 |
+
return {
|
| 439 |
+
"available": True,
|
| 440 |
+
"alpha": 0.05,
|
| 441 |
+
"k": k,
|
| 442 |
+
"n_seeds": n_seeds,
|
| 443 |
+
"critical_difference": round(cd, 4),
|
| 444 |
+
"mean_ranks": {m: round(float(r), 4) for m, r in mean_ranks.items()},
|
| 445 |
+
"pairwise": matrix,
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _plot_critical_difference_diagram(nemenyi: Dict[str, Any]) -> None:
|
| 450 |
+
"""Render a Demsar-style critical-difference diagram at results/plots/cd_diagram.png."""
|
| 451 |
+
if not nemenyi.get("available"):
|
| 452 |
+
return
|
| 453 |
+
mean_ranks: Dict[str, float] = nemenyi["mean_ranks"]
|
| 454 |
+
cd: float = nemenyi["critical_difference"]
|
| 455 |
+
methods = sorted(mean_ranks.keys(), key=lambda m: mean_ranks[m])
|
| 456 |
+
ranks = [mean_ranks[m] for m in methods]
|
| 457 |
+
k = len(methods)
|
| 458 |
+
|
| 459 |
+
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 460 |
+
fig, ax = _dark_fig(figsize=(12, 4 + 0.3 * k))
|
| 461 |
+
rank_min = min(ranks) - 0.5
|
| 462 |
+
rank_max = max(ranks) + 0.5
|
| 463 |
+
ax.set_xlim(rank_min, rank_max)
|
| 464 |
+
ax.set_ylim(0, k + 1)
|
| 465 |
+
ax.invert_xaxis()
|
| 466 |
+
ax.get_yaxis().set_visible(False)
|
| 467 |
+
for side in ("left", "right", "top"):
|
| 468 |
+
ax.spines[side].set_visible(False)
|
| 469 |
+
|
| 470 |
+
for i, m in enumerate(methods):
|
| 471 |
+
y = k - i
|
| 472 |
+
x = mean_ranks[m]
|
| 473 |
+
ax.plot([rank_min, x], [y, y], color="#445", linewidth=0.75)
|
| 474 |
+
ax.plot([x], [y], "o", color=COLORS[i % len(COLORS)], markersize=8)
|
| 475 |
+
ax.text(rank_min - 0.05 * (rank_max - rank_min), y,
|
| 476 |
+
f"{m} (rank {x:.2f})",
|
| 477 |
+
ha="right", va="center", color=TEXT_COL, fontsize=10)
|
| 478 |
+
|
| 479 |
+
cd_y = 0.5
|
| 480 |
+
ax.plot([min(ranks), min(ranks) + cd], [cd_y, cd_y], color="#e57373", linewidth=2.5)
|
| 481 |
+
ax.text(min(ranks) + cd / 2, cd_y - 0.25,
|
| 482 |
+
f"CD = {cd:.3f} (Nemenyi, α=0.05)",
|
| 483 |
+
ha="center", va="top", color="#e57373", fontsize=10)
|
| 484 |
+
|
| 485 |
+
ax.set_xlabel("Mean rank (lower = better)")
|
| 486 |
+
ax.set_title("Critical-Difference Diagram — total_tardiness", color=TEXT_COL, fontsize=13)
|
| 487 |
+
plt.tight_layout()
|
| 488 |
+
plt.savefig(PLOTS_DIR / "cd_diagram.png", dpi=150, facecolor=DARK_BG)
|
| 489 |
+
plt.close()
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def run_statistical_analysis(df: pd.DataFrame) -> Dict[str, Any]:
|
| 493 |
+
"""Run Friedman, Nemenyi post-hoc, direction-aware Wilcoxon, Cohen's d.
|
| 494 |
+
|
| 495 |
+
See Demsar (2006) JMLR 7:1-30 for the full protocol. The Wilcoxon test is
|
| 496 |
+
direction-aware: for lower-is-better metrics the alternative is
|
| 497 |
+
H1: baseline > DAHS; for higher-is-better metrics it is H1: DAHS > baseline.
|
| 498 |
+
Cohen's d is signed so positive d always means DAHS wins.
|
| 499 |
+
Holm-Bonferroni controls FWER within each metric family.
|
| 500 |
+
"""
|
| 501 |
+
methods = sorted(df["method"].unique())
|
| 502 |
+
|
| 503 |
+
primary_metric = "total_tardiness"
|
| 504 |
+
pivot = df.pivot_table(index="seed", columns="method", values=primary_metric)
|
| 505 |
+
pivot.dropna(inplace=True)
|
| 506 |
+
|
| 507 |
+
available_methods = [m for m in methods if m in pivot.columns]
|
| 508 |
+
|
| 509 |
+
results: Dict[str, Any] = {"primary_metric": primary_metric}
|
| 510 |
+
|
| 511 |
+
try:
|
| 512 |
+
data_arrays = [pivot[m].values for m in available_methods]
|
| 513 |
+
stat, p = stats.friedmanchisquare(*data_arrays)
|
| 514 |
+
results["friedman"] = {
|
| 515 |
+
"statistic": round(float(stat), 4),
|
| 516 |
+
"p_value": float(p),
|
| 517 |
+
"significant": bool(p < 0.05),
|
| 518 |
+
"metric": primary_metric,
|
| 519 |
+
}
|
| 520 |
+
logger.info("Friedman test: chi2=%.4f, p=%.6f", stat, p)
|
| 521 |
+
except Exception as e:
|
| 522 |
+
results["friedman"] = {"error": str(e)}
|
| 523 |
+
|
| 524 |
+
try:
|
| 525 |
+
nemenyi = _nemenyi_pairwise(pivot, available_methods)
|
| 526 |
+
results["nemenyi"] = nemenyi
|
| 527 |
+
if nemenyi.get("available"):
|
| 528 |
+
_plot_critical_difference_diagram(nemenyi)
|
| 529 |
+
logger.info("Nemenyi: CD=%.4f over k=%d methods, n=%d seeds",
|
| 530 |
+
nemenyi["critical_difference"], nemenyi["k"], nemenyi["n_seeds"])
|
| 531 |
+
except Exception as e:
|
| 532 |
+
results["nemenyi"] = {"error": str(e)}
|
| 533 |
+
|
| 534 |
+
# Pick the headline DAHS column. Order = best evidence first:
|
| 535 |
+
# 1. dahs_hybrid_* — ML prior + rolling-horizon fork oracle, the
|
| 536 |
+
# method we want the paper to highlight (guarantees
|
| 537 |
+
# at least best-fixed in expectation).
|
| 538 |
+
# 2. dahs_oracle — pure fork oracle, the upper-bound ablation.
|
| 539 |
+
# 3. dahs_* — greedy ML-only (BatchwiseSelector) ablation.
|
| 540 |
+
# 4. hybrid_priority — per-job GBR scorer; held LAST because its
|
| 541 |
+
# training CV R² was 0.022 ± 0.717. Keep it in
|
| 542 |
+
# the benchmark for completeness but do not let
|
| 543 |
+
# it lead headline numbers until regularised.
|
| 544 |
+
_priority = [
|
| 545 |
+
"dahs_hybrid_xgb", "dahs_hybrid_rf",
|
| 546 |
+
"dahs_oracle",
|
| 547 |
+
"dahs_xgb", "dahs_rf",
|
| 548 |
+
"hybrid_priority",
|
| 549 |
+
]
|
| 550 |
+
dahs_col = next((c for c in _priority if c in available_methods), None)
|
| 551 |
+
results["headline_method"] = dahs_col
|
| 552 |
+
if dahs_col is None:
|
| 553 |
+
results["wilcoxon"] = []
|
| 554 |
+
results["wilcoxon_secondary"] = {}
|
| 555 |
+
results["per_seed_dominance"] = {}
|
| 556 |
+
else:
|
| 557 |
+
results["wilcoxon"] = _wilcoxon_for_metric(
|
| 558 |
+
pivot, available_methods, dahs_col,
|
| 559 |
+
primary_metric, METRIC_DIRECTIONS[primary_metric],
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
# Per-seed dominance: on what fraction of seeds does the headline
|
| 563 |
+
# DAHS method beat each baseline on tardiness? This is the honest
|
| 564 |
+
# answer to the "does it win on every seed" question.
|
| 565 |
+
dominance: Dict[str, Any] = {"n_seeds": int(pivot.shape[0])}
|
| 566 |
+
per_baseline: Dict[str, Dict[str, Any]] = {}
|
| 567 |
+
beats_strongest_seeds = 0
|
| 568 |
+
# Identify "best baseline per seed" so we can compute win-rate vs
|
| 569 |
+
# the per-seed best fixed rule (the hardest comparison).
|
| 570 |
+
baseline_only = [m for m in available_methods
|
| 571 |
+
if m not in (
|
| 572 |
+
"dahs_xgb", "dahs_rf",
|
| 573 |
+
"dahs_hybrid_xgb", "dahs_hybrid_rf",
|
| 574 |
+
"dahs_oracle", "hybrid_priority",
|
| 575 |
+
"best_fixed_oracle",
|
| 576 |
+
)]
|
| 577 |
+
for method in available_methods:
|
| 578 |
+
if method == dahs_col:
|
| 579 |
+
continue
|
| 580 |
+
wins = int((pivot[dahs_col] < pivot[method]).sum())
|
| 581 |
+
ties = int((pivot[dahs_col] == pivot[method]).sum())
|
| 582 |
+
per_baseline[method] = {
|
| 583 |
+
"wins": wins,
|
| 584 |
+
"ties": ties,
|
| 585 |
+
"losses": int(pivot.shape[0] - wins - ties),
|
| 586 |
+
"win_rate": round(wins / max(pivot.shape[0], 1), 4),
|
| 587 |
+
}
|
| 588 |
+
if baseline_only:
|
| 589 |
+
best_per_seed = pivot[baseline_only].min(axis=1)
|
| 590 |
+
beats_strongest_seeds = int((pivot[dahs_col] < best_per_seed).sum())
|
| 591 |
+
dominance["wins_vs_best_fixed_per_seed"] = beats_strongest_seeds
|
| 592 |
+
dominance["win_rate_vs_best_fixed_per_seed"] = round(
|
| 593 |
+
beats_strongest_seeds / max(pivot.shape[0], 1), 4
|
| 594 |
+
)
|
| 595 |
+
dominance["per_baseline"] = per_baseline
|
| 596 |
+
results["per_seed_dominance"] = dominance
|
| 597 |
+
secondary: Dict[str, List[Dict[str, Any]]] = {}
|
| 598 |
+
for metric, direction in METRIC_DIRECTIONS.items():
|
| 599 |
+
if metric == primary_metric:
|
| 600 |
+
continue
|
| 601 |
+
piv_m = df.pivot_table(index="seed", columns="method", values=metric).dropna()
|
| 602 |
+
avail_m = [m for m in methods if m in piv_m.columns]
|
| 603 |
+
if dahs_col not in avail_m:
|
| 604 |
+
continue
|
| 605 |
+
secondary[metric] = _wilcoxon_for_metric(
|
| 606 |
+
piv_m, avail_m, dahs_col, metric, direction
|
| 607 |
+
)
|
| 608 |
+
results["wilcoxon_secondary"] = secondary
|
| 609 |
+
|
| 610 |
+
summary = []
|
| 611 |
+
for method in available_methods:
|
| 612 |
+
method_df = df[df["method"] == method]
|
| 613 |
+
summary.append({
|
| 614 |
+
"method": method,
|
| 615 |
+
"n": len(method_df),
|
| 616 |
+
"makespan_mean": round(float(method_df["makespan"].mean()), 2),
|
| 617 |
+
"makespan_std": round(float(method_df["makespan"].std()), 2),
|
| 618 |
+
"tardiness_mean": round(float(method_df["total_tardiness"].mean()), 2),
|
| 619 |
+
"tardiness_std": round(float(method_df["total_tardiness"].std()), 2),
|
| 620 |
+
"sla_mean": round(float(method_df["sla_breach_rate"].mean()), 4),
|
| 621 |
+
"cycle_mean": round(float(method_df["avg_cycle_time"].mean()), 2),
|
| 622 |
+
"throughput_mean": round(float(method_df["throughput"].mean()), 2),
|
| 623 |
+
})
|
| 624 |
+
results["summary"] = summary
|
| 625 |
+
|
| 626 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 627 |
+
with open(RESULTS_DIR / "statistical_tests.json", "w") as f:
|
| 628 |
+
json.dump(results, f, indent=2)
|
| 629 |
+
logger.info("Saved statistical_tests.json")
|
| 630 |
+
|
| 631 |
+
return results
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
# ---------------------------------------------------------------------------
|
| 635 |
+
# Switching analysis (NEW in DAHS_2)
|
| 636 |
+
# ---------------------------------------------------------------------------
|
| 637 |
+
|
| 638 |
+
def run_switching_analysis(df: pd.DataFrame) -> Dict[str, Any]:
|
| 639 |
+
"""Analyze DAHS switching behavior by running sample seeds with switching logs enabled."""
|
| 640 |
+
from src.heuristics import fifo_dispatch
|
| 641 |
+
from src.simulator import WarehouseSimulator
|
| 642 |
+
from src.features import FeatureExtractor
|
| 643 |
+
from src.hybrid_scheduler import BatchwiseSelector
|
| 644 |
+
import joblib as _joblib
|
| 645 |
+
|
| 646 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 647 |
+
|
| 648 |
+
sample_seeds = list(range(99000, 99010)) # 10 representative seeds
|
| 649 |
+
per_model: Dict[str, Any] = {}
|
| 650 |
+
|
| 651 |
+
for model_name in ("rf", "xgb"):
|
| 652 |
+
model_path = MODELS_DIR / f"selector_{model_name}.joblib"
|
| 653 |
+
if not model_path.exists():
|
| 654 |
+
logger.warning("Model not found: %s", model_path)
|
| 655 |
+
continue
|
| 656 |
+
|
| 657 |
+
model = _joblib.load(model_path)
|
| 658 |
+
total_evals = 0
|
| 659 |
+
total_switches = 0
|
| 660 |
+
total_hysteresis = 0
|
| 661 |
+
total_guardrails = 0
|
| 662 |
+
heuristic_counts: Dict[str, int] = {}
|
| 663 |
+
|
| 664 |
+
for seed in sample_seeds:
|
| 665 |
+
try:
|
| 666 |
+
fe = FeatureExtractor()
|
| 667 |
+
selector = BatchwiseSelector(model=model, feature_extractor=fe)
|
| 668 |
+
|
| 669 |
+
sim = WarehouseSimulator(seed=seed, heuristic_fn=fifo_dispatch, feature_extractor=fe)
|
| 670 |
+
|
| 671 |
+
def _make_dispatch(sel, s):
|
| 672 |
+
def _d(jobs, t, zone_id):
|
| 673 |
+
sel.update_state(s.get_state_snapshot())
|
| 674 |
+
return sel.dispatch(jobs, t, zone_id)
|
| 675 |
+
return _d
|
| 676 |
+
|
| 677 |
+
sim.heuristic_fn = _make_dispatch(selector, sim)
|
| 678 |
+
sim.run(duration=600.0)
|
| 679 |
+
|
| 680 |
+
summary = selector.switching_log.summary()
|
| 681 |
+
n_evals = summary.get("totalEvaluations", 0)
|
| 682 |
+
total_evals += n_evals
|
| 683 |
+
total_switches += summary.get("switchCount", 0)
|
| 684 |
+
total_hysteresis += summary.get("hysteresisBlocked", 0)
|
| 685 |
+
total_guardrails += summary.get("guardrailActivations", 0)
|
| 686 |
+
for h, frac in summary.get("distribution", {}).items():
|
| 687 |
+
heuristic_counts[h] = heuristic_counts.get(h, 0) + int(round(n_evals * frac))
|
| 688 |
+
|
| 689 |
+
except Exception as e:
|
| 690 |
+
logger.warning("Switching analysis seed %d (%s) failed: %s", seed, model_name, e)
|
| 691 |
+
|
| 692 |
+
n = len(sample_seeds)
|
| 693 |
+
total_h = sum(heuristic_counts.values())
|
| 694 |
+
per_model[f"dahs_{model_name}"] = {
|
| 695 |
+
"sample_seeds": n,
|
| 696 |
+
"avg_evaluations_per_run": round(total_evals / max(n, 1), 1),
|
| 697 |
+
"avg_switches_per_run": round(total_switches / max(n, 1), 1),
|
| 698 |
+
"avg_hysteresis_blocked_per_run": round(total_hysteresis / max(n, 1), 1),
|
| 699 |
+
"avg_guardrail_activations_per_run": round(total_guardrails / max(n, 1), 1),
|
| 700 |
+
"switching_rate_per_interval": round(total_switches / max(total_evals - n, 1), 4),
|
| 701 |
+
"heuristic_selection_distribution": {
|
| 702 |
+
h: round(c / max(total_h, 1), 4)
|
| 703 |
+
for h, c in sorted(heuristic_counts.items())
|
| 704 |
+
},
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
analysis = {
|
| 708 |
+
"description": "DAHS_2 batch-wise switching analysis (15-min intervals)",
|
| 709 |
+
**per_model,
|
| 710 |
+
}
|
| 711 |
+
|
| 712 |
+
with open(RESULTS_DIR / "switching_analysis.json", "w") as f:
|
| 713 |
+
json.dump(analysis, f, indent=2)
|
| 714 |
+
logger.info("Saved switching_analysis.json")
|
| 715 |
+
|
| 716 |
+
return analysis
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
# ---------------------------------------------------------------------------
|
| 720 |
+
# JSON export for frontend
|
| 721 |
+
# ---------------------------------------------------------------------------
|
| 722 |
+
|
| 723 |
+
def export_benchmark_json(df: pd.DataFrame) -> None:
|
| 724 |
+
"""Export summary JSON for the Results page frontend."""
|
| 725 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 726 |
+
|
| 727 |
+
methods = sorted(df["method"].unique())
|
| 728 |
+
summary = []
|
| 729 |
+
for method in methods:
|
| 730 |
+
mdf = df[df["method"] == method]
|
| 731 |
+
summary.append({
|
| 732 |
+
"method": method,
|
| 733 |
+
"n": len(mdf),
|
| 734 |
+
"tardiness": {"mean": float(mdf["total_tardiness"].mean()), "std": float(mdf["total_tardiness"].std())},
|
| 735 |
+
"sla": {"mean": float(mdf["sla_breach_rate"].mean()), "std": float(mdf["sla_breach_rate"].std())},
|
| 736 |
+
"cycle": {"mean": float(mdf["avg_cycle_time"].mean()), "std": float(mdf["avg_cycle_time"].std())},
|
| 737 |
+
"throughput": {"mean": float(mdf["throughput"].mean()), "std": float(mdf["throughput"].std())},
|
| 738 |
+
"makespan": {"mean": float(mdf["makespan"].mean()), "std": float(mdf["makespan"].std())},
|
| 739 |
+
})
|
| 740 |
+
|
| 741 |
+
with open(RESULTS_DIR / "benchmark_summary.json", "w") as f:
|
| 742 |
+
json.dump(summary, f, indent=2)
|
| 743 |
+
logger.info("Saved benchmark_summary.json")
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
# ---------------------------------------------------------------------------
|
| 747 |
+
# Plots (11 dark-theme plots)
|
| 748 |
+
# ---------------------------------------------------------------------------
|
| 749 |
+
|
| 750 |
+
def generate_plots(df: pd.DataFrame) -> None:
|
| 751 |
+
"""Generate all 11 dark-theme benchmark plots."""
|
| 752 |
+
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 753 |
+
|
| 754 |
+
methods = sorted(df["method"].unique())
|
| 755 |
+
method_colors = {m: COLORS[i % len(COLORS)] for i, m in enumerate(methods)}
|
| 756 |
+
|
| 757 |
+
# 1. Tardiness boxplot
|
| 758 |
+
fig, ax = _dark_fig(figsize=(14, 7))
|
| 759 |
+
data_by_method = [df[df["method"] == m]["total_tardiness"].dropna().values for m in methods]
|
| 760 |
+
bp = ax.boxplot(data_by_method, labels=methods, patch_artist=True)
|
| 761 |
+
for patch, method in zip(bp["boxes"], methods):
|
| 762 |
+
patch.set_facecolor(method_colors[method])
|
| 763 |
+
patch.set_alpha(0.75)
|
| 764 |
+
ax.set_title("Total Tardiness — All Methods", fontsize=14)
|
| 765 |
+
ax.set_xlabel("Method")
|
| 766 |
+
ax.set_ylabel("Total Tardiness (min)")
|
| 767 |
+
ax.tick_params(axis="x", rotation=35)
|
| 768 |
+
plt.tight_layout()
|
| 769 |
+
plt.savefig(PLOTS_DIR / "benchmark_tardiness.png", dpi=150, facecolor=DARK_BG)
|
| 770 |
+
plt.close()
|
| 771 |
+
|
| 772 |
+
# 2. SLA breach bar chart
|
| 773 |
+
fig, ax = _dark_fig(figsize=(12, 6))
|
| 774 |
+
sla_means = [df[df["method"] == m]["sla_breach_rate"].mean() * 100 for m in methods]
|
| 775 |
+
bars = ax.bar(methods, sla_means, color=[method_colors[m] for m in methods], alpha=0.85)
|
| 776 |
+
ax.set_title("Average SLA Breach Rate", fontsize=14)
|
| 777 |
+
ax.set_ylabel("SLA Breach Rate (%)")
|
| 778 |
+
ax.tick_params(axis="x", rotation=35)
|
| 779 |
+
for bar, val in zip(bars, sla_means):
|
| 780 |
+
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,
|
| 781 |
+
f"{val:.1f}%", ha="center", va="bottom", color=TEXT_COL, fontsize=9)
|
| 782 |
+
plt.tight_layout()
|
| 783 |
+
plt.savefig(PLOTS_DIR / "sla_breach_bar.png", dpi=150, facecolor=DARK_BG)
|
| 784 |
+
plt.close()
|
| 785 |
+
|
| 786 |
+
# 3. Zone utilization heatmap
|
| 787 |
+
try:
|
| 788 |
+
fig, ax = _dark_fig(figsize=(10, 6))
|
| 789 |
+
util_data = []
|
| 790 |
+
for m in methods:
|
| 791 |
+
mdf = df[df["method"] == m]
|
| 792 |
+
util_data.append([mdf["zone_utilization_avg"].mean()])
|
| 793 |
+
import seaborn as sns
|
| 794 |
+
sns.set_style("dark")
|
| 795 |
+
hm = ax.imshow([[v[0] for v in util_data]], aspect="auto", cmap="coolwarm")
|
| 796 |
+
ax.set_xticks(range(len(methods)))
|
| 797 |
+
ax.set_xticklabels(methods, rotation=35)
|
| 798 |
+
ax.set_yticklabels(["Avg Util"])
|
| 799 |
+
plt.colorbar(hm, ax=ax, label="Zone Utilization")
|
| 800 |
+
ax.set_title("Zone Utilization Heatmap", fontsize=14)
|
| 801 |
+
plt.tight_layout()
|
| 802 |
+
plt.savefig(PLOTS_DIR / "zone_utilization_heatmap.png", dpi=150, facecolor=DARK_BG)
|
| 803 |
+
plt.close()
|
| 804 |
+
except Exception:
|
| 805 |
+
pass
|
| 806 |
+
|
| 807 |
+
# 4. Radar chart
|
| 808 |
+
try:
|
| 809 |
+
categories = ["Tardiness↓", "SLA↓", "Cycle Time↓", "Throughput↑", "Utilization"]
|
| 810 |
+
n_cats = len(categories)
|
| 811 |
+
angles = np.linspace(0, 2 * np.pi, n_cats, endpoint=False).tolist()
|
| 812 |
+
angles += angles[:1]
|
| 813 |
+
|
| 814 |
+
fig = plt.figure(figsize=(10, 10))
|
| 815 |
+
fig.patch.set_facecolor(DARK_BG)
|
| 816 |
+
ax = fig.add_subplot(111, polar=True)
|
| 817 |
+
ax.set_facecolor(DARK_AX)
|
| 818 |
+
|
| 819 |
+
for i, method in enumerate(methods[:6]):
|
| 820 |
+
mdf = df[df["method"] == method]
|
| 821 |
+
values = [
|
| 822 |
+
1 - float(np.clip(mdf["total_tardiness"].mean() / max(df["total_tardiness"].max(), 1e-9), 0, 1)),
|
| 823 |
+
1 - float(mdf["sla_breach_rate"].mean()),
|
| 824 |
+
1 - float(np.clip(mdf["avg_cycle_time"].mean() / df["avg_cycle_time"].max(), 0, 1)),
|
| 825 |
+
float(np.clip(mdf["throughput"].mean() / df["throughput"].max(), 0, 1)),
|
| 826 |
+
float(mdf["zone_utilization_avg"].mean()),
|
| 827 |
+
]
|
| 828 |
+
values += values[:1]
|
| 829 |
+
ax.plot(angles, values, color=COLORS[i], linewidth=2, label=method)
|
| 830 |
+
ax.fill(angles, values, color=COLORS[i], alpha=0.1)
|
| 831 |
+
|
| 832 |
+
ax.set_xticks(angles[:-1])
|
| 833 |
+
ax.set_xticklabels(categories, color=TEXT_COL)
|
| 834 |
+
ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
|
| 835 |
+
ax.set_title("Performance Radar Chart", color=TEXT_COL, fontsize=14, pad=20)
|
| 836 |
+
plt.tight_layout()
|
| 837 |
+
plt.savefig(PLOTS_DIR / "radar_chart.png", dpi=150, facecolor=DARK_BG)
|
| 838 |
+
plt.close()
|
| 839 |
+
except Exception:
|
| 840 |
+
pass
|
| 841 |
+
|
| 842 |
+
# 5. Pareto front (makespan vs tardiness)
|
| 843 |
+
fig, ax = _dark_fig(figsize=(10, 7))
|
| 844 |
+
for method in methods:
|
| 845 |
+
mdf = df[df["method"] == method]
|
| 846 |
+
ax.scatter(
|
| 847 |
+
mdf["makespan"].mean(),
|
| 848 |
+
mdf["total_tardiness"].mean(),
|
| 849 |
+
color=method_colors[method],
|
| 850 |
+
s=120, label=method, zorder=5,
|
| 851 |
+
)
|
| 852 |
+
ax.set_title("Pareto Front: Makespan vs Tardiness", fontsize=14)
|
| 853 |
+
ax.set_xlabel("Mean Makespan (min)")
|
| 854 |
+
ax.set_ylabel("Mean Total Tardiness (min)")
|
| 855 |
+
ax.legend(facecolor=DARK_AX, labelcolor=TEXT_COL)
|
| 856 |
+
plt.tight_layout()
|
| 857 |
+
plt.savefig(PLOTS_DIR / "pareto_front.png", dpi=150, facecolor=DARK_BG)
|
| 858 |
+
plt.close()
|
| 859 |
+
|
| 860 |
+
# 6. Throughput comparison
|
| 861 |
+
fig, ax = _dark_fig(figsize=(12, 6))
|
| 862 |
+
thru_means = [df[df["method"] == m]["throughput"].mean() for m in methods]
|
| 863 |
+
ax.bar(methods, thru_means, color=[method_colors[m] for m in methods], alpha=0.85)
|
| 864 |
+
ax.set_title("Average Throughput (jobs/hour)", fontsize=14)
|
| 865 |
+
ax.set_ylabel("Throughput (jobs/hr)")
|
| 866 |
+
ax.tick_params(axis="x", rotation=35)
|
| 867 |
+
plt.tight_layout()
|
| 868 |
+
plt.savefig(PLOTS_DIR / "throughput_comparison.png", dpi=150, facecolor=DARK_BG)
|
| 869 |
+
plt.close()
|
| 870 |
+
|
| 871 |
+
logger.info("Generated plots in %s", PLOTS_DIR)
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
# ---------------------------------------------------------------------------
|
| 875 |
+
# Full evaluation pipeline
|
| 876 |
+
# ---------------------------------------------------------------------------
|
| 877 |
+
|
| 878 |
+
def run_full_evaluation(
|
| 879 |
+
seeds: Optional[List[int]] = None,
|
| 880 |
+
n_workers: int = 4,
|
| 881 |
+
) -> Dict[str, Any]:
|
| 882 |
+
"""Run complete evaluation: benchmark + stats + plots + JSON export."""
|
| 883 |
+
df = run_benchmark(seeds=seeds, n_workers=n_workers)
|
| 884 |
+
stats_results = run_statistical_analysis(df)
|
| 885 |
+
switching = run_switching_analysis(df)
|
| 886 |
+
export_benchmark_json(df)
|
| 887 |
+
generate_plots(df)
|
| 888 |
+
|
| 889 |
+
return {
|
| 890 |
+
"benchmark": df,
|
| 891 |
+
"stats": stats_results,
|
| 892 |
+
"switching": switching,
|
| 893 |
+
}
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
if __name__ == "__main__":
|
| 897 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 898 |
+
# Quick test with 20 seeds
|
| 899 |
+
run_full_evaluation(seeds=list(range(99000, 99020)), n_workers=2)
|
src/features.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
features.py — Feature Extraction for Hybrid Warehouse Scheduler
|
| 3 |
+
|
| 4 |
+
Implements a stateful FeatureExtractor that computes 39 features split into:
|
| 5 |
+
- 32 scenario-level features describing system-wide state
|
| 6 |
+
(including 4 disruption-aware + 10 composition-adaptive novel features)
|
| 7 |
+
- 7 job-level features for per-job priority prediction
|
| 8 |
+
|
| 9 |
+
NEW in DAHS_2:
|
| 10 |
+
- get_feature_ranges() method: returns {feature_name: (min, max)} from training data
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import math
|
| 18 |
+
from collections import deque
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
# -------------------------------------------------------------------------
|
| 27 |
+
# Feature name lists (used for DataFrame column labeling)
|
| 28 |
+
# -------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
SCENARIO_FEATURE_NAMES: List[str] = [
|
| 31 |
+
"n_orders_in_system",
|
| 32 |
+
"n_express_orders_pct",
|
| 33 |
+
"avg_due_date_tightness",
|
| 34 |
+
"fraction_already_late",
|
| 35 |
+
"zone_utilization_avg",
|
| 36 |
+
"zone_utilization_std",
|
| 37 |
+
"bottleneck_zone",
|
| 38 |
+
"avg_remaining_proc_time",
|
| 39 |
+
"std_remaining_proc_time",
|
| 40 |
+
"throughput_last_30min",
|
| 41 |
+
"breakdown_flag",
|
| 42 |
+
"n_broken_stations",
|
| 43 |
+
"lunch_break_flag",
|
| 44 |
+
"surge_multiplier",
|
| 45 |
+
"batch_pending_flag",
|
| 46 |
+
"avg_priority_weight",
|
| 47 |
+
"max_tardiness_so_far",
|
| 48 |
+
"sla_breach_rate_current",
|
| 49 |
+
# Disruption-aware features (novel contribution)
|
| 50 |
+
"disruption_intensity",
|
| 51 |
+
"queue_imbalance",
|
| 52 |
+
"job_mix_entropy",
|
| 53 |
+
"time_pressure_ratio",
|
| 54 |
+
# Composition-adaptive features (novel contribution, DAHS 2.1)
|
| 55 |
+
"pct_type_A",
|
| 56 |
+
"pct_type_B",
|
| 57 |
+
"pct_type_C",
|
| 58 |
+
"pct_type_D",
|
| 59 |
+
"pct_type_E",
|
| 60 |
+
"count_type_A",
|
| 61 |
+
"count_type_B",
|
| 62 |
+
"count_type_C",
|
| 63 |
+
"count_type_D",
|
| 64 |
+
"count_type_E",
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
JOB_FEATURE_NAMES: List[str] = [
|
| 68 |
+
"job_type_encoded",
|
| 69 |
+
"proc_time_next_station",
|
| 70 |
+
"remaining_proc_time",
|
| 71 |
+
"time_to_due",
|
| 72 |
+
"time_in_system",
|
| 73 |
+
"critical_ratio",
|
| 74 |
+
"station_queue_at_next",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
FEATURE_DESCRIPTIONS = {
|
| 78 |
+
"n_orders_in_system": "Total jobs currently in the system (waiting + processing)",
|
| 79 |
+
"n_express_orders_pct": "Fraction of waiting jobs that are express (type E)",
|
| 80 |
+
"avg_due_date_tightness": "Average (due_date - now) for waiting jobs",
|
| 81 |
+
"fraction_already_late": "Fraction of waiting jobs past their due date",
|
| 82 |
+
"zone_utilization_avg": "Average utilization across all 8 zones",
|
| 83 |
+
"zone_utilization_std": "Std deviation of zone utilization (imbalance indicator)",
|
| 84 |
+
"bottleneck_zone": "Utilization of the most-loaded zone",
|
| 85 |
+
"avg_remaining_proc_time": "Average remaining processing time for waiting jobs",
|
| 86 |
+
"std_remaining_proc_time": "Std deviation of remaining processing times",
|
| 87 |
+
"throughput_last_30min": "Jobs completed per minute in the last 30 minutes",
|
| 88 |
+
"breakdown_flag": "1 if any station is currently broken, else 0",
|
| 89 |
+
"n_broken_stations": "Number of stations currently under repair",
|
| 90 |
+
"lunch_break_flag": "1 if shift is currently in lunch break (t=300-360), else 0",
|
| 91 |
+
"surge_multiplier": "Current time-of-day arrival rate multiplier",
|
| 92 |
+
"batch_pending_flag": "1 if a truck batch arrival is imminent",
|
| 93 |
+
"avg_priority_weight": "Average priority weight of waiting jobs",
|
| 94 |
+
"max_tardiness_so_far": "Maximum job tardiness observed so far",
|
| 95 |
+
"sla_breach_rate_current": "Fraction of completed jobs that breached SLA",
|
| 96 |
+
"disruption_intensity": "[NOVEL] Composite disruption score: breakdowns + lunch + surge",
|
| 97 |
+
"queue_imbalance": "[NOVEL] Coefficient of variation of queue sizes across zones",
|
| 98 |
+
"job_mix_entropy": "[NOVEL] Shannon entropy of job-type distribution in queue",
|
| 99 |
+
"time_pressure_ratio": "[NOVEL] Fraction of waiting jobs with Critical Ratio < 1",
|
| 100 |
+
"pct_type_A": "[NOVEL] Fraction of waiting jobs of type A (standard)",
|
| 101 |
+
"pct_type_B": "[NOVEL] Fraction of waiting jobs of type B (picking-intensive)",
|
| 102 |
+
"pct_type_C": "[NOVEL] Fraction of waiting jobs of type C (value-add)",
|
| 103 |
+
"pct_type_D": "[NOVEL] Fraction of waiting jobs of type D (complex/bulk)",
|
| 104 |
+
"pct_type_E": "[NOVEL] Fraction of waiting jobs of type E (express)",
|
| 105 |
+
"count_type_A": "[NOVEL] Absolute count of waiting type-A jobs",
|
| 106 |
+
"count_type_B": "[NOVEL] Absolute count of waiting type-B jobs",
|
| 107 |
+
"count_type_C": "[NOVEL] Absolute count of waiting type-C jobs",
|
| 108 |
+
"count_type_D": "[NOVEL] Absolute count of waiting type-D jobs",
|
| 109 |
+
"count_type_E": "[NOVEL] Absolute count of waiting type-E jobs",
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
# Job type → integer encoding
|
| 113 |
+
_JOB_TYPE_ENC: Dict[str, int] = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4}
|
| 114 |
+
|
| 115 |
+
# Job type → priority weight (mirrors simulator definitions)
|
| 116 |
+
_JOB_PRIORITY_WEIGHT: Dict[str, float] = {
|
| 117 |
+
"A": 2.0, "B": 1.5, "C": 1.0, "D": 0.8, "E": 3.0
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FeatureExtractor:
|
| 122 |
+
"""Stateful extractor that maintains running statistics across events.
|
| 123 |
+
|
| 124 |
+
Call ``update(event_type, data)`` as events occur during simulation,
|
| 125 |
+
then call ``extract_scenario_features`` or ``extract_job_features``
|
| 126 |
+
to obtain the feature vectors.
|
| 127 |
+
|
| 128 |
+
NEW in DAHS_2:
|
| 129 |
+
- get_feature_ranges(): returns {feature_name: (min, max)} from a training DataFrame
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
# Window size for throughput tracking (minutes)
|
| 133 |
+
THROUGHPUT_WINDOW = 30.0
|
| 134 |
+
|
| 135 |
+
def __init__(self) -> None:
|
| 136 |
+
# Circular buffer of (timestamp, job_id) for throughput window
|
| 137 |
+
self._completion_times: deque = deque()
|
| 138 |
+
# Batch pending flag set externally when a truck batch is imminent
|
| 139 |
+
self.batch_pending: bool = False
|
| 140 |
+
# Stored feature ranges for OOD detection (set after training)
|
| 141 |
+
self._feature_ranges: Optional[Dict[str, Tuple[float, float]]] = None
|
| 142 |
+
# Metadata loaded alongside the ranges (run hash etc.) — used by the
|
| 143 |
+
# selector loader to detect stale artifacts.
|
| 144 |
+
self._feature_ranges_meta: Dict[str, Any] = {}
|
| 145 |
+
|
| 146 |
+
# ------------------------------------------------------------------
|
| 147 |
+
# Event update
|
| 148 |
+
# ------------------------------------------------------------------
|
| 149 |
+
|
| 150 |
+
def update(self, event_type: str, data: Dict[str, Any]) -> None:
|
| 151 |
+
"""Update running statistics on job events."""
|
| 152 |
+
if event_type == "job_complete":
|
| 153 |
+
self._completion_times.append(data.get("timestamp", 0.0))
|
| 154 |
+
|
| 155 |
+
# ------------------------------------------------------------------
|
| 156 |
+
# Scenario-level features (22)
|
| 157 |
+
# ------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
def extract_scenario_features(self, sim_state: Dict[str, Any]) -> np.ndarray:
|
| 160 |
+
"""Extract 32 scenario-level features from a system state snapshot.
|
| 161 |
+
|
| 162 |
+
22 system-state features (F1-F22, including 4 disruption-aware novel)
|
| 163 |
+
+ 10 composition-adaptive features (F23-F32, novel in DAHS 2.1).
|
| 164 |
+
|
| 165 |
+
Parameters
|
| 166 |
+
----------
|
| 167 |
+
sim_state : dict
|
| 168 |
+
Output of ``WarehouseSimulator.get_state_snapshot()``.
|
| 169 |
+
|
| 170 |
+
Returns
|
| 171 |
+
-------
|
| 172 |
+
np.ndarray of shape (32,)
|
| 173 |
+
"""
|
| 174 |
+
now: float = sim_state.get("current_time", 0.0)
|
| 175 |
+
waiting_jobs: List[Any] = sim_state.get("waiting_jobs", [])
|
| 176 |
+
completed_jobs: List[Any] = sim_state.get("completed_jobs", [])
|
| 177 |
+
queue_sizes: Dict[int, int] = sim_state.get("queue_sizes", {})
|
| 178 |
+
zone_util: Dict[int, float] = sim_state.get("zone_utilization", {})
|
| 179 |
+
n_broken: int = sim_state.get("n_broken_stations", 0)
|
| 180 |
+
lunch: bool = sim_state.get("lunch_active", False)
|
| 181 |
+
surge: float = sim_state.get("surge_multiplier", 1.0)
|
| 182 |
+
|
| 183 |
+
# F1: n_orders_in_system
|
| 184 |
+
n_in_system = float(sim_state.get("n_orders_in_system", 0))
|
| 185 |
+
|
| 186 |
+
# F2: n_express_orders_pct
|
| 187 |
+
n_express = sum(1 for j in waiting_jobs if j.job_type == "E")
|
| 188 |
+
n_express_pct = n_express / max(1.0, n_in_system)
|
| 189 |
+
|
| 190 |
+
# F3: avg_due_date_tightness = avg(due_date - now) for waiting jobs
|
| 191 |
+
if waiting_jobs:
|
| 192 |
+
tightness = float(np.mean([j.due_date - now for j in waiting_jobs]))
|
| 193 |
+
else:
|
| 194 |
+
tightness = 999.0
|
| 195 |
+
|
| 196 |
+
# F4: fraction_already_late
|
| 197 |
+
if waiting_jobs:
|
| 198 |
+
frac_late = sum(1 for j in waiting_jobs if j.due_date < now) / len(waiting_jobs)
|
| 199 |
+
else:
|
| 200 |
+
frac_late = 0.0
|
| 201 |
+
|
| 202 |
+
# F5/F6: zone utilization avg and std
|
| 203 |
+
util_vals = list(zone_util.values())
|
| 204 |
+
util_avg = float(np.mean(util_vals)) if util_vals else 0.0
|
| 205 |
+
util_std = float(np.std(util_vals)) if util_vals else 0.0
|
| 206 |
+
|
| 207 |
+
# F7: bottleneck_zone (utilization value of the most-loaded zone)
|
| 208 |
+
# Bug fix from DAHS_1: use max(zone_util.values()) NOT zone_id
|
| 209 |
+
if zone_util:
|
| 210 |
+
bottleneck = float(max(zone_util.values()))
|
| 211 |
+
else:
|
| 212 |
+
bottleneck = 0.0
|
| 213 |
+
|
| 214 |
+
# F8/F9: avg and std remaining proc time for waiting jobs
|
| 215 |
+
rem_times = [j.remaining_proc_time() for j in waiting_jobs]
|
| 216 |
+
avg_rem = float(np.mean(rem_times)) if rem_times else 0.0
|
| 217 |
+
std_rem = float(np.std(rem_times)) if rem_times else 0.0
|
| 218 |
+
|
| 219 |
+
# F10: throughput in last 30 min (completions per minute)
|
| 220 |
+
cutoff = now - self.THROUGHPUT_WINDOW
|
| 221 |
+
while self._completion_times and self._completion_times[0] < cutoff:
|
| 222 |
+
self._completion_times.popleft()
|
| 223 |
+
throughput_30 = len(self._completion_times) / self.THROUGHPUT_WINDOW
|
| 224 |
+
|
| 225 |
+
# F11: breakdown_flag
|
| 226 |
+
breakdown_flag = 1.0 if n_broken > 0 else 0.0
|
| 227 |
+
|
| 228 |
+
# F12: n_broken_stations
|
| 229 |
+
n_broken_f = float(n_broken)
|
| 230 |
+
|
| 231 |
+
# F13: lunch_break_flag
|
| 232 |
+
lunch_flag = 1.0 if lunch else 0.0
|
| 233 |
+
|
| 234 |
+
# F14: surge_multiplier
|
| 235 |
+
surge_f = float(surge)
|
| 236 |
+
|
| 237 |
+
# F15: batch_pending_flag
|
| 238 |
+
batch_flag = 1.0 if self.batch_pending else 0.0
|
| 239 |
+
|
| 240 |
+
# F16: avg_priority_weight
|
| 241 |
+
if waiting_jobs:
|
| 242 |
+
avg_prio_w = float(np.mean([
|
| 243 |
+
_JOB_PRIORITY_WEIGHT.get(j.job_type, 1.0) for j in waiting_jobs
|
| 244 |
+
]))
|
| 245 |
+
else:
|
| 246 |
+
avg_prio_w = 1.0
|
| 247 |
+
|
| 248 |
+
# F17: max_tardiness_so_far
|
| 249 |
+
if completed_jobs:
|
| 250 |
+
max_tard = float(max(
|
| 251 |
+
max(0.0, j.completion_time - j.due_date) for j in completed_jobs
|
| 252 |
+
))
|
| 253 |
+
else:
|
| 254 |
+
max_tard = 0.0
|
| 255 |
+
|
| 256 |
+
# F18: sla_breach_rate_current
|
| 257 |
+
if completed_jobs:
|
| 258 |
+
breach_rate = sum(
|
| 259 |
+
1 for j in completed_jobs if j.completion_time > j.due_date
|
| 260 |
+
) / len(completed_jobs)
|
| 261 |
+
else:
|
| 262 |
+
breach_rate = 0.0
|
| 263 |
+
|
| 264 |
+
# F19: disruption_intensity — composite disruption score [0, 1]
|
| 265 |
+
breakdown_severity = min(1.0, n_broken / 5.0)
|
| 266 |
+
lunch_severity = 1.0 if lunch else 0.0
|
| 267 |
+
surge_deviation = abs(surge - 1.0)
|
| 268 |
+
disruption_intensity = 0.5 * breakdown_severity + 0.25 * lunch_severity + 0.25 * surge_deviation
|
| 269 |
+
|
| 270 |
+
# F20: queue_imbalance — coefficient of variation of queue sizes
|
| 271 |
+
# Bug fix: guard with mean > 1e-6 (not > 0)
|
| 272 |
+
q_vals = list(queue_sizes.values())
|
| 273 |
+
if q_vals and np.mean(q_vals) > 1e-6:
|
| 274 |
+
queue_imbalance = float(min(np.std(q_vals) / np.mean(q_vals), 10.0))
|
| 275 |
+
else:
|
| 276 |
+
queue_imbalance = 0.0
|
| 277 |
+
|
| 278 |
+
# F21: job_mix_entropy — Shannon entropy of job type distribution in queue
|
| 279 |
+
if waiting_jobs:
|
| 280 |
+
type_counts: Dict[str, int] = {}
|
| 281 |
+
for j in waiting_jobs:
|
| 282 |
+
type_counts[j.job_type] = type_counts.get(j.job_type, 0) + 1
|
| 283 |
+
total_w = len(waiting_jobs)
|
| 284 |
+
job_mix_entropy = 0.0
|
| 285 |
+
for cnt in type_counts.values():
|
| 286 |
+
p = cnt / total_w
|
| 287 |
+
if p > 0:
|
| 288 |
+
job_mix_entropy -= p * math.log2(p)
|
| 289 |
+
else:
|
| 290 |
+
job_mix_entropy = 0.0
|
| 291 |
+
|
| 292 |
+
# F22: time_pressure_ratio — fraction of waiting jobs with CR < 1
|
| 293 |
+
if waiting_jobs:
|
| 294 |
+
n_under_pressure = 0
|
| 295 |
+
for j in waiting_jobs:
|
| 296 |
+
rem = j.remaining_proc_time()
|
| 297 |
+
ttd = j.due_date - now
|
| 298 |
+
cr = ttd / max(rem, 0.001) if rem > 0 else 999.0
|
| 299 |
+
if cr < 1.0:
|
| 300 |
+
n_under_pressure += 1
|
| 301 |
+
time_pressure_ratio = n_under_pressure / len(waiting_jobs)
|
| 302 |
+
else:
|
| 303 |
+
time_pressure_ratio = 0.0
|
| 304 |
+
|
| 305 |
+
# F23-F32: composition-adaptive features (per-type % and absolute counts)
|
| 306 |
+
# These give the selector explicit, non-lossy signal about the current
|
| 307 |
+
# batch composition — crucial for heuristic adaptation.
|
| 308 |
+
type_counts: Dict[str, int] = {"A": 0, "B": 0, "C": 0, "D": 0, "E": 0}
|
| 309 |
+
for j in waiting_jobs:
|
| 310 |
+
if j.job_type in type_counts:
|
| 311 |
+
type_counts[j.job_type] += 1
|
| 312 |
+
total_w = max(len(waiting_jobs), 1)
|
| 313 |
+
pct_A = type_counts["A"] / total_w if waiting_jobs else 0.0
|
| 314 |
+
pct_B = type_counts["B"] / total_w if waiting_jobs else 0.0
|
| 315 |
+
pct_C = type_counts["C"] / total_w if waiting_jobs else 0.0
|
| 316 |
+
pct_D = type_counts["D"] / total_w if waiting_jobs else 0.0
|
| 317 |
+
pct_E = type_counts["E"] / total_w if waiting_jobs else 0.0
|
| 318 |
+
|
| 319 |
+
features = np.array([
|
| 320 |
+
n_in_system, # F1
|
| 321 |
+
n_express_pct, # F2
|
| 322 |
+
tightness, # F3
|
| 323 |
+
frac_late, # F4
|
| 324 |
+
util_avg, # F5
|
| 325 |
+
util_std, # F6
|
| 326 |
+
bottleneck, # F7
|
| 327 |
+
avg_rem, # F8
|
| 328 |
+
std_rem, # F9
|
| 329 |
+
throughput_30, # F10
|
| 330 |
+
breakdown_flag, # F11
|
| 331 |
+
n_broken_f, # F12
|
| 332 |
+
lunch_flag, # F13
|
| 333 |
+
surge_f, # F14
|
| 334 |
+
batch_flag, # F15
|
| 335 |
+
avg_prio_w, # F16
|
| 336 |
+
max_tard, # F17
|
| 337 |
+
breach_rate, # F18
|
| 338 |
+
disruption_intensity, # F19 (novel)
|
| 339 |
+
queue_imbalance, # F20 (novel)
|
| 340 |
+
job_mix_entropy, # F21 (novel)
|
| 341 |
+
time_pressure_ratio, # F22 (novel)
|
| 342 |
+
pct_A, # F23 (novel)
|
| 343 |
+
pct_B, # F24 (novel)
|
| 344 |
+
pct_C, # F25 (novel)
|
| 345 |
+
pct_D, # F26 (novel)
|
| 346 |
+
pct_E, # F27 (novel)
|
| 347 |
+
float(type_counts["A"]),# F28 (novel)
|
| 348 |
+
float(type_counts["B"]),# F29 (novel)
|
| 349 |
+
float(type_counts["C"]),# F30 (novel)
|
| 350 |
+
float(type_counts["D"]),# F31 (novel)
|
| 351 |
+
float(type_counts["E"]),# F32 (novel)
|
| 352 |
+
], dtype=np.float64)
|
| 353 |
+
|
| 354 |
+
# Sanitize: replace NaN/inf with safe values (training pipeline bug fix)
|
| 355 |
+
features = np.nan_to_num(features, nan=0.0, posinf=999.0, neginf=-999.0)
|
| 356 |
+
|
| 357 |
+
return features.astype(np.float32)
|
| 358 |
+
|
| 359 |
+
# ------------------------------------------------------------------
|
| 360 |
+
# Job-level features (7)
|
| 361 |
+
# ------------------------------------------------------------------
|
| 362 |
+
|
| 363 |
+
def extract_job_features(self, job: Any, sim_state: Dict[str, Any]) -> np.ndarray:
|
| 364 |
+
"""Extract 7 job-level features for priority prediction."""
|
| 365 |
+
now: float = sim_state.get("current_time", 0.0)
|
| 366 |
+
queue_sizes: Dict[int, int] = sim_state.get("queue_sizes", {})
|
| 367 |
+
|
| 368 |
+
jt_enc = float(_JOB_TYPE_ENC.get(job.job_type, 0))
|
| 369 |
+
|
| 370 |
+
if not job.is_complete:
|
| 371 |
+
next_op = job.operations[job.current_op_idx]
|
| 372 |
+
proc_next = float(next_op.nominal_proc_time)
|
| 373 |
+
else:
|
| 374 |
+
proc_next = 0.0
|
| 375 |
+
|
| 376 |
+
rem_proc = float(job.remaining_proc_time())
|
| 377 |
+
time_to_due = float(job.due_date - now)
|
| 378 |
+
time_in_sys = float(now - job.arrival_time)
|
| 379 |
+
|
| 380 |
+
if rem_proc > 0:
|
| 381 |
+
cr = time_to_due / rem_proc
|
| 382 |
+
else:
|
| 383 |
+
cr = 999.0 # large finite value, safe for ML models
|
| 384 |
+
|
| 385 |
+
if not job.is_complete:
|
| 386 |
+
next_zone = job.operations[job.current_op_idx].zone_id
|
| 387 |
+
queue_at_next = float(queue_sizes.get(next_zone, 0))
|
| 388 |
+
else:
|
| 389 |
+
queue_at_next = 0.0
|
| 390 |
+
|
| 391 |
+
features = np.array([
|
| 392 |
+
jt_enc,
|
| 393 |
+
proc_next,
|
| 394 |
+
rem_proc,
|
| 395 |
+
time_to_due,
|
| 396 |
+
time_in_sys,
|
| 397 |
+
cr,
|
| 398 |
+
queue_at_next,
|
| 399 |
+
], dtype=np.float32)
|
| 400 |
+
|
| 401 |
+
return features
|
| 402 |
+
|
| 403 |
+
# ------------------------------------------------------------------
|
| 404 |
+
# Feature names
|
| 405 |
+
# ------------------------------------------------------------------
|
| 406 |
+
|
| 407 |
+
def get_feature_names(self, level: str = "scenario") -> List[str]:
|
| 408 |
+
"""Return the ordered list of feature names."""
|
| 409 |
+
if level == "scenario":
|
| 410 |
+
return SCENARIO_FEATURE_NAMES
|
| 411 |
+
elif level == "job":
|
| 412 |
+
return JOB_FEATURE_NAMES
|
| 413 |
+
elif level == "all":
|
| 414 |
+
return SCENARIO_FEATURE_NAMES + JOB_FEATURE_NAMES
|
| 415 |
+
else:
|
| 416 |
+
raise ValueError(f"Unknown level: {level!r}. Use 'scenario', 'job', or 'all'.")
|
| 417 |
+
|
| 418 |
+
# ------------------------------------------------------------------
|
| 419 |
+
# NEW in DAHS_2: Feature ranges for OOD detection
|
| 420 |
+
# ------------------------------------------------------------------
|
| 421 |
+
|
| 422 |
+
def get_feature_ranges(
|
| 423 |
+
self,
|
| 424 |
+
X_train: Optional[np.ndarray] = None,
|
| 425 |
+
feature_names: Optional[List[str]] = None,
|
| 426 |
+
) -> Dict[str, Tuple[float, float]]:
|
| 427 |
+
"""Compute {feature_name: (min, max)} from training data.
|
| 428 |
+
|
| 429 |
+
If X_train is None, returns stored ranges (set by set_feature_ranges()).
|
| 430 |
+
|
| 431 |
+
Parameters
|
| 432 |
+
----------
|
| 433 |
+
X_train : np.ndarray of shape (n_samples, 22)
|
| 434 |
+
Training feature matrix. If None, returns cached ranges.
|
| 435 |
+
feature_names : list of str, optional
|
| 436 |
+
Column names. Defaults to SCENARIO_FEATURE_NAMES.
|
| 437 |
+
|
| 438 |
+
Returns
|
| 439 |
+
-------
|
| 440 |
+
dict mapping feature_name -> (min_val, max_val)
|
| 441 |
+
"""
|
| 442 |
+
if X_train is None:
|
| 443 |
+
if self._feature_ranges is None:
|
| 444 |
+
raise ValueError("No training data provided and no cached feature ranges.")
|
| 445 |
+
return self._feature_ranges
|
| 446 |
+
|
| 447 |
+
names = feature_names or SCENARIO_FEATURE_NAMES
|
| 448 |
+
ranges = {}
|
| 449 |
+
for i, name in enumerate(names):
|
| 450 |
+
if i < X_train.shape[1]:
|
| 451 |
+
ranges[name] = (float(X_train[:, i].min()), float(X_train[:, i].max()))
|
| 452 |
+
self._feature_ranges = ranges
|
| 453 |
+
return ranges
|
| 454 |
+
|
| 455 |
+
def set_feature_ranges(self, ranges: Dict[str, Tuple[float, float]]) -> None:
|
| 456 |
+
"""Set feature ranges for OOD detection (loaded from JSON artifact)."""
|
| 457 |
+
self._feature_ranges = ranges
|
| 458 |
+
|
| 459 |
+
def load_feature_ranges(self, json_path: "Union[Path, str]") -> Dict[str, Tuple[float, float]]:
|
| 460 |
+
"""Load feature ranges from a JSON file saved by train_selector.py.
|
| 461 |
+
|
| 462 |
+
Accepts both the legacy flat format ({feature_name: [min, max]}) and
|
| 463 |
+
the wrapped format ({"_meta": {...}, "ranges": {feature_name: [...]}}).
|
| 464 |
+
Stores any meta payload on `self._feature_ranges_meta` so callers can
|
| 465 |
+
verify the artifact was produced in the same training run as the model.
|
| 466 |
+
"""
|
| 467 |
+
with open(json_path, "r") as f:
|
| 468 |
+
data = json.load(f)
|
| 469 |
+
if isinstance(data, dict) and "ranges" in data:
|
| 470 |
+
self._feature_ranges_meta = data.get("_meta", {})
|
| 471 |
+
raw = data["ranges"]
|
| 472 |
+
else:
|
| 473 |
+
self._feature_ranges_meta = {}
|
| 474 |
+
raw = data
|
| 475 |
+
ranges = {k: (v[0], v[1]) for k, v in raw.items()}
|
| 476 |
+
self._feature_ranges = ranges
|
| 477 |
+
return ranges
|
| 478 |
+
|
| 479 |
+
def is_out_of_distribution(
|
| 480 |
+
self,
|
| 481 |
+
features: np.ndarray,
|
| 482 |
+
tolerance: float = 0.10,
|
| 483 |
+
) -> bool:
|
| 484 |
+
"""Check if any feature falls outside training range ±10%.
|
| 485 |
+
|
| 486 |
+
Parameters
|
| 487 |
+
----------
|
| 488 |
+
features : np.ndarray of shape (22,)
|
| 489 |
+
Scenario features to check.
|
| 490 |
+
tolerance : float
|
| 491 |
+
Fractional tolerance beyond training range (default 10%).
|
| 492 |
+
|
| 493 |
+
Returns
|
| 494 |
+
-------
|
| 495 |
+
bool: True if OOD
|
| 496 |
+
"""
|
| 497 |
+
if self._feature_ranges is None:
|
| 498 |
+
return False # no ranges loaded → assume in-distribution
|
| 499 |
+
|
| 500 |
+
for i, name in enumerate(SCENARIO_FEATURE_NAMES):
|
| 501 |
+
if name not in self._feature_ranges:
|
| 502 |
+
continue
|
| 503 |
+
lo, hi = self._feature_ranges[name]
|
| 504 |
+
val = float(features[i])
|
| 505 |
+
span = max(hi - lo, 1e-6)
|
| 506 |
+
if val < lo - tolerance * span or val > hi + tolerance * span:
|
| 507 |
+
return True
|
| 508 |
+
return False
|
src/heuristics.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
heuristics.py — Dispatch Heuristics for Warehouse Job Shop Scheduling
|
| 3 |
+
|
| 4 |
+
Provides six industry-standard dispatch rules plus stub wrappers for
|
| 5 |
+
ML-driven hybrid dispatch (filled in by hybrid_scheduler.py).
|
| 6 |
+
|
| 7 |
+
Academic References
|
| 8 |
+
-------------------
|
| 9 |
+
- FIFO (First-In First-Out):
|
| 10 |
+
Standard queue discipline; no specific citation needed.
|
| 11 |
+
|
| 12 |
+
- Priority-EDD (Earliest Due Date):
|
| 13 |
+
Jackson, J.R. (1955). Scheduling a production line to minimize
|
| 14 |
+
maximum tardiness. Management Research Project Report 43, UCLA.
|
| 15 |
+
|
| 16 |
+
- Critical Ratio (CR):
|
| 17 |
+
Conway, R.W., Maxwell, W.L., & Miller, L.W. (1967). Theory of
|
| 18 |
+
Scheduling. Addison-Wesley.
|
| 19 |
+
Also: Pinedo, M.L. (2016). Scheduling: Theory, Algorithms, and
|
| 20 |
+
Systems. Springer (5th ed.). doi:10.1007/978-3-319-26580-3.
|
| 21 |
+
|
| 22 |
+
- ATC (Apparent Tardiness Cost):
|
| 23 |
+
Vepsalainen, A.P.J. & Morton, T.E. (1987). Priority rules for job
|
| 24 |
+
shops with weighted tardiness costs. Management Science, 33(8),
|
| 25 |
+
1035-1047. doi:10.1287/mnsc.33.8.1035.
|
| 26 |
+
|
| 27 |
+
- WSPT (Weighted Shortest Processing Time):
|
| 28 |
+
Smith, W.E. (1956). Various optimizers for single-stage production.
|
| 29 |
+
Naval Research Logistics Quarterly, 3(1-2), 59-66.
|
| 30 |
+
doi:10.1002/nav.3800030106. [Optimal for weighted completion time.]
|
| 31 |
+
|
| 32 |
+
- Slack (Minimum Slack):
|
| 33 |
+
Pinedo, M.L. (2016). Scheduling: Theory, Algorithms, and Systems.
|
| 34 |
+
Springer (5th ed.). doi:10.1007/978-3-319-26580-3.
|
| 35 |
+
|
| 36 |
+
Hyper-heuristic framework (ML selection over these 6 rules):
|
| 37 |
+
Burke, E.K. et al. (2013). Hyper-heuristics: A survey of the state
|
| 38 |
+
of the art. JORS, 64(12), 1695-1724. doi:10.1057/jors.2013.71.
|
| 39 |
+
Cowling, P., Kendall, G., & Soubeiga, E. (2001). A hyperheuristic
|
| 40 |
+
approach to scheduling a sales summit. PATAT 2000, LNCS 2079.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
from __future__ import annotations
|
| 44 |
+
|
| 45 |
+
import math
|
| 46 |
+
import logging
|
| 47 |
+
from typing import Any, Dict, List
|
| 48 |
+
|
| 49 |
+
logger = logging.getLogger(__name__)
|
| 50 |
+
|
| 51 |
+
# Priority class mapping (higher number = higher priority in dispatch)
|
| 52 |
+
_PRIORITY_CLASS: Dict[str, int] = {
|
| 53 |
+
"E": 4, # Express — highest
|
| 54 |
+
"A": 3,
|
| 55 |
+
"C": 2,
|
| 56 |
+
"B": 1,
|
| 57 |
+
"D": 0, # Deferred — lowest
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_priority_class(job_type: str) -> int:
|
| 62 |
+
"""Return numeric priority class for a job type string."""
|
| 63 |
+
return _PRIORITY_CLASS.get(job_type, 1)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def compute_critical_ratio(job: Any, current_time: float) -> float:
|
| 67 |
+
"""Compute the Critical Ratio for a job.
|
| 68 |
+
|
| 69 |
+
CR = time_remaining_to_due / remaining_processing_time
|
| 70 |
+
|
| 71 |
+
A CR < 1 means the job is behind schedule. Negative CR means already late.
|
| 72 |
+
CR = 999.0 is returned when remaining_proc = 0 (done job — large finite value).
|
| 73 |
+
"""
|
| 74 |
+
time_to_due = job.due_date - current_time
|
| 75 |
+
remaining_proc = job.remaining_proc_time()
|
| 76 |
+
|
| 77 |
+
if remaining_proc <= 0:
|
| 78 |
+
return 999.0 # done job — large finite value, sorts last in ascending CR dispatch
|
| 79 |
+
if time_to_due <= 0:
|
| 80 |
+
return time_to_due / remaining_proc # negative CR = already late
|
| 81 |
+
|
| 82 |
+
return time_to_due / remaining_proc
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
# Baseline heuristics
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
|
| 89 |
+
# Ref: Standard queue discipline — no specific academic citation required.
|
| 90 |
+
def fifo_dispatch(jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 91 |
+
"""First-In First-Out dispatch: sort by arrival_time ascending."""
|
| 92 |
+
return sorted(jobs, key=lambda j: j.arrival_time)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Ref: Jackson (1955), "Scheduling a production line to minimize maximum tardiness",
|
| 96 |
+
# Management Research Project Report 43, UCLA.
|
| 97 |
+
# Extended with priority classes for multi-tier fulfillment environments.
|
| 98 |
+
def priority_edd_dispatch(jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 99 |
+
"""Priority-EDD dispatch: sort by (priority_class DESC, due_date ASC)."""
|
| 100 |
+
return sorted(
|
| 101 |
+
jobs,
|
| 102 |
+
key=lambda j: (-get_priority_class(j.job_type), j.due_date),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Ref: Conway et al. (1967), "Theory of Scheduling", Addison-Wesley.
|
| 107 |
+
# Also: Pinedo (2016), "Scheduling: Theory, Algorithms, and Systems", Springer 5th ed.
|
| 108 |
+
def critical_ratio_dispatch(jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 109 |
+
"""Critical Ratio dispatch: sort by CR ascending (most urgent first)."""
|
| 110 |
+
return sorted(jobs, key=lambda j: compute_critical_ratio(j, current_time))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Priority weight mapping (mirrors simulator definitions)
|
| 114 |
+
_PRIORITY_WEIGHT: Dict[str, float] = {
|
| 115 |
+
"A": 2.0, "B": 1.5, "C": 1.0, "D": 0.8, "E": 3.0,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# Ref: Vepsalainen, A.P.J. & Morton, T.E. (1987). Priority rules for job shops
|
| 120 |
+
# with weighted tardiness costs. Management Science, 33(8), 1035-1047.
|
| 121 |
+
# doi:10.1287/mnsc.33.8.1035
|
| 122 |
+
def atc_dispatch(jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 123 |
+
"""Apparent Tardiness Cost (ATC) dispatch.
|
| 124 |
+
|
| 125 |
+
ATC_i = (w_i / p_i) * exp(-max(0, d_i - p_i - t) / (K * p_avg))
|
| 126 |
+
|
| 127 |
+
where K is the look-ahead parameter (K=2.0), p_avg is the average
|
| 128 |
+
remaining processing time across waiting jobs.
|
| 129 |
+
Higher ATC score → dispatch sooner.
|
| 130 |
+
|
| 131 |
+
Reference: Vepsalainen & Morton (1987), Management Science 33(8):1035-1047.
|
| 132 |
+
"""
|
| 133 |
+
if not jobs:
|
| 134 |
+
return jobs
|
| 135 |
+
|
| 136 |
+
p_vals = [max(j.remaining_proc_time(), 0.001) for j in jobs]
|
| 137 |
+
p_avg = sum(p_vals) / len(p_vals)
|
| 138 |
+
K = 2.0 # look-ahead parameter
|
| 139 |
+
|
| 140 |
+
def _atc_score(job: Any) -> float:
|
| 141 |
+
w = _PRIORITY_WEIGHT.get(job.job_type, 1.0)
|
| 142 |
+
p = max(job.remaining_proc_time(), 0.001)
|
| 143 |
+
slack = job.due_date - p - current_time
|
| 144 |
+
urgency = math.exp(-max(0.0, slack) / max(K * p_avg, 0.001))
|
| 145 |
+
return (w / p) * urgency
|
| 146 |
+
|
| 147 |
+
return sorted(jobs, key=_atc_score, reverse=True)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# Ref: Smith, W.E. (1956). Various optimizers for single-stage production.
|
| 151 |
+
# Naval Research Logistics Quarterly, 3(1-2), 59-66.
|
| 152 |
+
# doi:10.1002/nav.3800030106
|
| 153 |
+
# [Proven optimal for minimizing weighted completion time on a single machine.]
|
| 154 |
+
def wspt_dispatch(jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 155 |
+
"""Weighted Shortest Processing Time (WSPT) dispatch.
|
| 156 |
+
|
| 157 |
+
Sort by w_i / p_i descending — prioritizes jobs with high
|
| 158 |
+
priority-to-processing-time ratio.
|
| 159 |
+
|
| 160 |
+
Reference: Smith (1956), Naval Research Logistics Quarterly 3(1-2):59-66.
|
| 161 |
+
"""
|
| 162 |
+
def _wspt_score(job: Any) -> float:
|
| 163 |
+
w = _PRIORITY_WEIGHT.get(job.job_type, 1.0)
|
| 164 |
+
p = max(job.remaining_proc_time(), 0.001)
|
| 165 |
+
return w / p
|
| 166 |
+
|
| 167 |
+
return sorted(jobs, key=_wspt_score, reverse=True)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Ref: Pinedo, M.L. (2016). Scheduling: Theory, Algorithms, and Systems.
|
| 171 |
+
# Springer, 5th edition. doi:10.1007/978-3-319-26580-3.
|
| 172 |
+
def slack_dispatch(jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 173 |
+
"""Slack-based dispatch: sort by remaining slack ascending.
|
| 174 |
+
|
| 175 |
+
Slack = (due_date - current_time) - remaining_proc_time
|
| 176 |
+
Lower slack → less margin → dispatch sooner.
|
| 177 |
+
|
| 178 |
+
Reference: Pinedo (2016), Scheduling: Theory, Algorithms, and Systems.
|
| 179 |
+
"""
|
| 180 |
+
def _slack(job: Any) -> float:
|
| 181 |
+
return (job.due_date - current_time) - job.remaining_proc_time()
|
| 182 |
+
|
| 183 |
+
return sorted(jobs, key=_slack)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Dispatch map for convenience
|
| 187 |
+
DISPATCH_MAP = {
|
| 188 |
+
"fifo": fifo_dispatch,
|
| 189 |
+
"priority_edd": priority_edd_dispatch,
|
| 190 |
+
"critical_ratio": critical_ratio_dispatch,
|
| 191 |
+
"atc": atc_dispatch,
|
| 192 |
+
"wspt": wspt_dispatch,
|
| 193 |
+
"slack": slack_dispatch,
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
ALL_HEURISTICS = list(DISPATCH_MAP.keys())
|
| 197 |
+
HEURISTIC_LABELS = ["FIFO", "Priority-EDD", "Critical-Ratio", "ATC", "WSPT", "Slack"]
|
src/hybrid_scheduler.py
ADDED
|
@@ -0,0 +1,865 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
hybrid_scheduler.py — Batch-wise ML Hybrid Scheduler with Guardrails (DAHS_2)
|
| 3 |
+
|
| 4 |
+
NEW architecture vs DAHS_1:
|
| 5 |
+
- BatchwiseSelector: re-evaluates every 15 min OR on disruption events
|
| 6 |
+
- Hysteresis: only switches if >15% more confident
|
| 7 |
+
- Edge case guardrails: trivial load, overload, OOD detection
|
| 8 |
+
- Starvation prevention: force-promote jobs waiting >60 min
|
| 9 |
+
- 3-level interpretability log per evaluation
|
| 10 |
+
- Plain English explanations
|
| 11 |
+
|
| 12 |
+
Also includes (ported from DAHS_1):
|
| 13 |
+
- SwitchingLog class
|
| 14 |
+
- HybridPriority class
|
| 15 |
+
- Factory functions
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import joblib
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
MODELS_DIR = Path(__file__).parent.parent / "models"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Switching Log (enhanced for DAHS_2 with evaluation payload)
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
class SwitchingLog:
|
| 37 |
+
"""Records every batch-wise heuristic-selection evaluation made by BatchwiseSelector.
|
| 38 |
+
|
| 39 |
+
DAHS_2: Each entry contains full evaluation context including probabilities,
|
| 40 |
+
top features, reason, and plain-English explanation.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]
|
| 44 |
+
|
| 45 |
+
def __init__(self) -> None:
|
| 46 |
+
self.entries: List[Dict[str, Any]] = []
|
| 47 |
+
self._last_heuristic: Optional[str] = None
|
| 48 |
+
self._switch_count: int = 0
|
| 49 |
+
self._hysteresis_blocked: int = 0
|
| 50 |
+
self._guardrail_activations: int = 0
|
| 51 |
+
|
| 52 |
+
def record(
|
| 53 |
+
self,
|
| 54 |
+
time: float,
|
| 55 |
+
features: List[float],
|
| 56 |
+
probabilities: Dict[str, float],
|
| 57 |
+
selected: str,
|
| 58 |
+
switched: bool,
|
| 59 |
+
reason: str,
|
| 60 |
+
confidence: float,
|
| 61 |
+
top_features: List[Dict[str, Any]],
|
| 62 |
+
plain_english: str,
|
| 63 |
+
) -> None:
|
| 64 |
+
"""Record one batch evaluation."""
|
| 65 |
+
if switched:
|
| 66 |
+
self._switch_count += 1
|
| 67 |
+
if reason == "hysteresis_blocked":
|
| 68 |
+
self._hysteresis_blocked += 1
|
| 69 |
+
if reason.startswith("guardrail"):
|
| 70 |
+
self._guardrail_activations += 1
|
| 71 |
+
self._last_heuristic = selected
|
| 72 |
+
|
| 73 |
+
self.entries.append({
|
| 74 |
+
"time": round(time, 2),
|
| 75 |
+
"features": [round(float(f), 4) for f in features],
|
| 76 |
+
"probabilities": {k: round(float(v), 4) for k, v in probabilities.items()},
|
| 77 |
+
"selected": selected,
|
| 78 |
+
"switched": switched,
|
| 79 |
+
"reason": reason,
|
| 80 |
+
"confidence": round(confidence, 4),
|
| 81 |
+
"topFeatures": top_features,
|
| 82 |
+
"plainEnglish": plain_english,
|
| 83 |
+
})
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def total_evaluations(self) -> int:
|
| 87 |
+
return len(self.entries)
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def switch_count(self) -> int:
|
| 91 |
+
return self._switch_count
|
| 92 |
+
|
| 93 |
+
def heuristic_distribution(self) -> Dict[str, float]:
|
| 94 |
+
"""Fraction of evaluations assigned to each heuristic."""
|
| 95 |
+
if not self.entries:
|
| 96 |
+
return {}
|
| 97 |
+
counts: Dict[str, int] = {}
|
| 98 |
+
for e in self.entries:
|
| 99 |
+
h = e["selected"]
|
| 100 |
+
counts[h] = counts.get(h, 0) + 1
|
| 101 |
+
total = len(self.entries)
|
| 102 |
+
return {h: c / total for h, c in sorted(counts.items())}
|
| 103 |
+
|
| 104 |
+
def switching_rate(self) -> float:
|
| 105 |
+
"""Switches per evaluation."""
|
| 106 |
+
if len(self.entries) < 2:
|
| 107 |
+
return 0.0
|
| 108 |
+
return self._switch_count / (len(self.entries) - 1)
|
| 109 |
+
|
| 110 |
+
def summary(self) -> Dict[str, Any]:
|
| 111 |
+
"""Return a human-readable summary dict."""
|
| 112 |
+
dist = self.heuristic_distribution()
|
| 113 |
+
return {
|
| 114 |
+
"totalEvaluations": self.total_evaluations,
|
| 115 |
+
"switchCount": self._switch_count,
|
| 116 |
+
"switchingRate": round(self.switching_rate(), 4),
|
| 117 |
+
"hysteresisBlocked": self._hysteresis_blocked,
|
| 118 |
+
"guardrailActivations": self._guardrail_activations,
|
| 119 |
+
"distribution": {k: round(v, 4) for k, v in dist.items()},
|
| 120 |
+
"dominantHeuristic": max(dist, key=dist.get) if dist else "none",
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
def to_list(self) -> List[Dict[str, Any]]:
|
| 124 |
+
"""Return entries as a plain list for JSON serialization."""
|
| 125 |
+
return self.entries
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
# BatchwiseSelector — Core DAHS_2 scheduler
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
|
| 132 |
+
class BatchwiseSelector:
|
| 133 |
+
"""Batch-wise ML heuristic selector with guardrails and hysteresis.
|
| 134 |
+
|
| 135 |
+
Re-evaluates every 15 minutes OR on disruption events (breakdown,
|
| 136 |
+
batch arrival, lunch state change). Only switches if new heuristic
|
| 137 |
+
is >15% more confident (hysteresis).
|
| 138 |
+
|
| 139 |
+
Edge-case guardrails:
|
| 140 |
+
- Trivial: n_orders < 5 → use FIFO
|
| 141 |
+
- Overload: avg_utilization > 0.92 → lock to ATC + alert
|
| 142 |
+
- OOD: features outside training range ±10% → safe fallback to ATC
|
| 143 |
+
- Starvation: any job waiting >60 min → force-promote
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
EVAL_INTERVAL = 15.0 # minutes between re-evaluations
|
| 147 |
+
# Relative margin: new heuristic's probability must exceed current × (1 + margin).
|
| 148 |
+
# Calibration-invariant across RF (broad) and XGB (sharp) predict_proba outputs.
|
| 149 |
+
HYSTERESIS_MARGIN = 0.15
|
| 150 |
+
TRIVIAL_LOAD = 5 # skip ML if fewer jobs
|
| 151 |
+
OVERLOAD_THRESHOLD = 0.92 # lock to ATC
|
| 152 |
+
STARVATION_LIMIT = 60.0 # force-promote starving jobs (minutes)
|
| 153 |
+
|
| 154 |
+
HEURISTIC_MAP = {
|
| 155 |
+
0: "fifo", 1: "priority_edd", 2: "critical_ratio",
|
| 156 |
+
3: "atc", 4: "wspt", 5: "slack",
|
| 157 |
+
}
|
| 158 |
+
HEURISTIC_LABELS = {
|
| 159 |
+
"fifo": "FIFO", "priority_edd": "Priority-EDD",
|
| 160 |
+
"critical_ratio": "Critical-Ratio", "atc": "ATC",
|
| 161 |
+
"wspt": "WSPT", "slack": "Slack",
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
# Plain-English reason templates
|
| 165 |
+
_EXPLANATION_MAP = {
|
| 166 |
+
("atc", "time_pressure_ratio"): "many jobs are nearing their deadlines",
|
| 167 |
+
("atc", "surge_multiplier"): "demand surging above normal rate",
|
| 168 |
+
("atc", "zone_utilization_avg"): "warehouse is highly loaded",
|
| 169 |
+
("critical_ratio", "n_broken_stations"): "station breakdowns are causing bottlenecks",
|
| 170 |
+
("critical_ratio", "disruption_intensity"): "high disruption intensity detected",
|
| 171 |
+
("fifo", "zone_utilization_avg"): "load is light, simple ordering is optimal",
|
| 172 |
+
("fifo", "n_orders_in_system"): "few jobs in system, FIFO is stable",
|
| 173 |
+
("wspt", "avg_priority_weight"): "high-value short jobs should be prioritized",
|
| 174 |
+
("wspt", "avg_remaining_proc_time"): "many short jobs in queue",
|
| 175 |
+
("priority_edd", "n_express_orders_pct"): "high fraction of express orders",
|
| 176 |
+
("priority_edd", "fraction_already_late"): "many jobs past due date",
|
| 177 |
+
("slack", "avg_due_date_tightness"): "deadlines are extremely tight",
|
| 178 |
+
("slack", "sla_breach_rate_current"): "SLA breach rate is rising",
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
def __init__(
|
| 182 |
+
self,
|
| 183 |
+
model: Any,
|
| 184 |
+
feature_extractor: Any,
|
| 185 |
+
feature_importances: Optional[np.ndarray] = None,
|
| 186 |
+
feature_names: Optional[List[str]] = None,
|
| 187 |
+
) -> None:
|
| 188 |
+
self._model = model
|
| 189 |
+
self._fe = feature_extractor
|
| 190 |
+
self._feature_importances = feature_importances
|
| 191 |
+
self._feature_names = feature_names or []
|
| 192 |
+
|
| 193 |
+
self._current_heuristic: str = "fifo"
|
| 194 |
+
self._current_confidence: float = 0.0
|
| 195 |
+
self._current_from_guardrail: bool = False
|
| 196 |
+
self._last_eval_time: float = -999.0
|
| 197 |
+
self._last_breakdown_count: int = 0
|
| 198 |
+
self._last_lunch_state: bool = False
|
| 199 |
+
|
| 200 |
+
self.switching_log = SwitchingLog()
|
| 201 |
+
self._sim_state: Optional[Dict[str, Any]] = None
|
| 202 |
+
|
| 203 |
+
def update_state(self, sim_state: Dict[str, Any]) -> None:
|
| 204 |
+
"""Update stored simulation state (called before dispatch)."""
|
| 205 |
+
self._sim_state = sim_state
|
| 206 |
+
|
| 207 |
+
# ------------------------------------------------------------------
|
| 208 |
+
# Main dispatch interface
|
| 209 |
+
# ------------------------------------------------------------------
|
| 210 |
+
|
| 211 |
+
def dispatch(
|
| 212 |
+
self,
|
| 213 |
+
jobs: List[Any],
|
| 214 |
+
current_time: float,
|
| 215 |
+
zone_id: int,
|
| 216 |
+
) -> List[Any]:
|
| 217 |
+
"""Apply current heuristic, potentially re-evaluating first.
|
| 218 |
+
|
| 219 |
+
This is the main entry point called by the simulator's heuristic_fn.
|
| 220 |
+
Re-evaluates every 15 min or on disruption events.
|
| 221 |
+
"""
|
| 222 |
+
from src.heuristics import (
|
| 223 |
+
fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
|
| 224 |
+
atc_dispatch, wspt_dispatch, slack_dispatch,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
dispatch_fns: Dict[str, Callable] = {
|
| 228 |
+
"fifo": fifo_dispatch,
|
| 229 |
+
"priority_edd": priority_edd_dispatch,
|
| 230 |
+
"critical_ratio": critical_ratio_dispatch,
|
| 231 |
+
"atc": atc_dispatch,
|
| 232 |
+
"wspt": wspt_dispatch,
|
| 233 |
+
"slack": slack_dispatch,
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
if not jobs:
|
| 237 |
+
return jobs
|
| 238 |
+
|
| 239 |
+
# Re-evaluate if needed (time-based or event-triggered)
|
| 240 |
+
if self._sim_state is not None and self._should_reevaluate(current_time):
|
| 241 |
+
self._reevaluate(current_time)
|
| 242 |
+
|
| 243 |
+
# Starvation prevention: force-promote any job waiting >60 min
|
| 244 |
+
fn = dispatch_fns.get(self._current_heuristic, fifo_dispatch)
|
| 245 |
+
ordered = fn(jobs, current_time, zone_id)
|
| 246 |
+
ordered = self._apply_starvation_prevention(ordered, current_time)
|
| 247 |
+
|
| 248 |
+
return ordered
|
| 249 |
+
|
| 250 |
+
def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 251 |
+
"""Callable interface (same as dispatch)."""
|
| 252 |
+
return self.dispatch(jobs, current_time, zone_id)
|
| 253 |
+
|
| 254 |
+
# ------------------------------------------------------------------
|
| 255 |
+
# Re-evaluation logic
|
| 256 |
+
# ------------------------------------------------------------------
|
| 257 |
+
|
| 258 |
+
def _should_reevaluate(self, now: float) -> bool:
|
| 259 |
+
"""Return True if we should re-evaluate the heuristic selection."""
|
| 260 |
+
if self._sim_state is None:
|
| 261 |
+
return False
|
| 262 |
+
|
| 263 |
+
# Time-based: every 15 minutes
|
| 264 |
+
if now - self._last_eval_time >= self.EVAL_INTERVAL:
|
| 265 |
+
return True
|
| 266 |
+
|
| 267 |
+
# Event: breakdown count changed
|
| 268 |
+
n_broken = self._sim_state.get("n_broken_stations", 0)
|
| 269 |
+
if n_broken != self._last_breakdown_count:
|
| 270 |
+
return True
|
| 271 |
+
|
| 272 |
+
# Event: lunch state changed
|
| 273 |
+
lunch = self._sim_state.get("lunch_active", False)
|
| 274 |
+
if lunch != self._last_lunch_state:
|
| 275 |
+
return True
|
| 276 |
+
|
| 277 |
+
return False
|
| 278 |
+
|
| 279 |
+
def _reevaluate(self, now: float) -> None:
|
| 280 |
+
"""Perform ML evaluation and decide whether to switch heuristic."""
|
| 281 |
+
if self._sim_state is None:
|
| 282 |
+
return
|
| 283 |
+
|
| 284 |
+
self._last_eval_time = now
|
| 285 |
+
self._last_breakdown_count = self._sim_state.get("n_broken_stations", 0)
|
| 286 |
+
self._last_lunch_state = self._sim_state.get("lunch_active", False)
|
| 287 |
+
|
| 288 |
+
# Extract features
|
| 289 |
+
try:
|
| 290 |
+
features = self._fe.extract_scenario_features(self._sim_state)
|
| 291 |
+
except Exception as e:
|
| 292 |
+
logger.warning("Feature extraction failed: %s", e)
|
| 293 |
+
return
|
| 294 |
+
|
| 295 |
+
# Check guardrails first
|
| 296 |
+
guardrail = self._check_guardrails(features)
|
| 297 |
+
if guardrail is not None:
|
| 298 |
+
# Guardrail triggered — record and switch if needed
|
| 299 |
+
switched = guardrail != self._current_heuristic
|
| 300 |
+
plain = f"Guardrail active: {guardrail.replace('guardrail_', '')}. Using {guardrail} as safe default."
|
| 301 |
+
probas = {h: (1.0 if h == guardrail else 0.0) for h in self.HEURISTIC_MAP.values()}
|
| 302 |
+
top_features = self._get_top_features(features, n=5)
|
| 303 |
+
|
| 304 |
+
reason_map = {
|
| 305 |
+
"fifo": "guardrail_trivial",
|
| 306 |
+
"atc": "guardrail_overload" if self._sim_state.get("zone_utilization", {}) else "guardrail_ood",
|
| 307 |
+
}
|
| 308 |
+
reason = reason_map.get(guardrail, f"guardrail_{guardrail}")
|
| 309 |
+
|
| 310 |
+
self.switching_log.record(
|
| 311 |
+
time=now,
|
| 312 |
+
features=features.tolist(),
|
| 313 |
+
probabilities=probas,
|
| 314 |
+
selected=guardrail,
|
| 315 |
+
switched=switched,
|
| 316 |
+
reason=reason,
|
| 317 |
+
confidence=1.0,
|
| 318 |
+
top_features=top_features,
|
| 319 |
+
plain_english=f"Guardrail active. Using {self.HEURISTIC_LABELS.get(guardrail, guardrail)} as safe default.",
|
| 320 |
+
)
|
| 321 |
+
self._current_heuristic = guardrail
|
| 322 |
+
self._current_confidence = 1.0
|
| 323 |
+
self._current_from_guardrail = True
|
| 324 |
+
return
|
| 325 |
+
|
| 326 |
+
# ML prediction
|
| 327 |
+
try:
|
| 328 |
+
X = features.reshape(1, -1)
|
| 329 |
+
probas_arr = self._model.predict_proba(X)[0]
|
| 330 |
+
new_idx = int(np.argmax(probas_arr))
|
| 331 |
+
new_heuristic = self.HEURISTIC_MAP.get(new_idx, "fifo")
|
| 332 |
+
new_confidence = float(probas_arr[new_idx])
|
| 333 |
+
|
| 334 |
+
probas_dict = {
|
| 335 |
+
self.HEURISTIC_MAP[i]: float(p)
|
| 336 |
+
for i, p in enumerate(probas_arr)
|
| 337 |
+
if i in self.HEURISTIC_MAP
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
except Exception as e:
|
| 341 |
+
logger.warning("ML prediction failed: %s", e)
|
| 342 |
+
return
|
| 343 |
+
|
| 344 |
+
# Relative-margin hysteresis: switch only if the new heuristic's probability
|
| 345 |
+
# exceeds the current × (1 + HYSTERESIS_MARGIN). This is calibration-invariant
|
| 346 |
+
# across RF (broad probs) and XGB (sharp probs), unlike an additive threshold.
|
| 347 |
+
# Bypassed when current was forced by a guardrail (prevents lock-in on FIFO
|
| 348 |
+
# at t=0 when system was empty).
|
| 349 |
+
if (not self._current_from_guardrail
|
| 350 |
+
and new_heuristic != self._current_heuristic
|
| 351 |
+
and new_confidence < self._current_confidence * (1.0 + self.HYSTERESIS_MARGIN)):
|
| 352 |
+
# Blocked by hysteresis
|
| 353 |
+
top_features = self._get_top_features(features, n=5)
|
| 354 |
+
self.switching_log.record(
|
| 355 |
+
time=now,
|
| 356 |
+
features=features.tolist(),
|
| 357 |
+
probabilities=probas_dict,
|
| 358 |
+
selected=self._current_heuristic,
|
| 359 |
+
switched=False,
|
| 360 |
+
reason="hysteresis_blocked",
|
| 361 |
+
confidence=new_confidence,
|
| 362 |
+
top_features=top_features,
|
| 363 |
+
plain_english=(
|
| 364 |
+
f"ML suggests {self.HEURISTIC_LABELS.get(new_heuristic, new_heuristic)} "
|
| 365 |
+
f"({new_confidence:.0%} confident) but hysteresis threshold not met. "
|
| 366 |
+
f"Keeping {self.HEURISTIC_LABELS.get(self._current_heuristic, self._current_heuristic)}."
|
| 367 |
+
),
|
| 368 |
+
)
|
| 369 |
+
return
|
| 370 |
+
|
| 371 |
+
# Switch (or keep) accepted
|
| 372 |
+
switched = new_heuristic != self._current_heuristic
|
| 373 |
+
top_features = self._get_top_features(features, n=5)
|
| 374 |
+
plain_english = self._generate_explanation(features, new_heuristic, "ml_decision", probas_dict)
|
| 375 |
+
|
| 376 |
+
self.switching_log.record(
|
| 377 |
+
time=now,
|
| 378 |
+
features=features.tolist(),
|
| 379 |
+
probabilities=probas_dict,
|
| 380 |
+
selected=new_heuristic,
|
| 381 |
+
switched=switched,
|
| 382 |
+
reason="ml_decision",
|
| 383 |
+
confidence=new_confidence,
|
| 384 |
+
top_features=top_features,
|
| 385 |
+
plain_english=plain_english,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
self._current_heuristic = new_heuristic
|
| 389 |
+
self._current_confidence = new_confidence
|
| 390 |
+
self._current_from_guardrail = False
|
| 391 |
+
|
| 392 |
+
def _check_guardrails(self, features: np.ndarray) -> Optional[str]:
|
| 393 |
+
"""Check edge-case guardrails. Returns heuristic name or None."""
|
| 394 |
+
from src.features import SCENARIO_FEATURE_NAMES
|
| 395 |
+
|
| 396 |
+
feat_dict = dict(zip(SCENARIO_FEATURE_NAMES, features.tolist()))
|
| 397 |
+
|
| 398 |
+
# Guardrail 1: Trivial load
|
| 399 |
+
n_orders = feat_dict.get("n_orders_in_system", 0)
|
| 400 |
+
if n_orders < self.TRIVIAL_LOAD:
|
| 401 |
+
return "fifo"
|
| 402 |
+
|
| 403 |
+
# Guardrail 2: Overload
|
| 404 |
+
util_avg = feat_dict.get("zone_utilization_avg", 0.0)
|
| 405 |
+
if util_avg > self.OVERLOAD_THRESHOLD:
|
| 406 |
+
return "atc"
|
| 407 |
+
|
| 408 |
+
# Guardrail 3: OOD detection
|
| 409 |
+
if self._fe._feature_ranges is not None:
|
| 410 |
+
if self._fe.is_out_of_distribution(features, tolerance=0.10):
|
| 411 |
+
return "atc"
|
| 412 |
+
|
| 413 |
+
return None
|
| 414 |
+
|
| 415 |
+
def _apply_starvation_prevention(
|
| 416 |
+
self,
|
| 417 |
+
jobs: List[Any],
|
| 418 |
+
current_time: float,
|
| 419 |
+
) -> List[Any]:
|
| 420 |
+
"""Force-promote jobs that have been waiting >60 minutes.
|
| 421 |
+
|
| 422 |
+
Moves starving jobs to the front of the queue regardless of heuristic.
|
| 423 |
+
"""
|
| 424 |
+
starving = [j for j in jobs if (current_time - j.arrival_time) > self.STARVATION_LIMIT]
|
| 425 |
+
non_starving = [j for j in jobs if j not in starving]
|
| 426 |
+
return starving + non_starving
|
| 427 |
+
|
| 428 |
+
def _get_top_features(self, features: np.ndarray, n: int = 5) -> List[Dict[str, Any]]:
|
| 429 |
+
"""Return top-n features by importance with current values."""
|
| 430 |
+
from src.features import SCENARIO_FEATURE_NAMES
|
| 431 |
+
|
| 432 |
+
feat_names = self._feature_names or SCENARIO_FEATURE_NAMES
|
| 433 |
+
|
| 434 |
+
if self._feature_importances is not None:
|
| 435 |
+
top_idx = np.argsort(self._feature_importances)[::-1][:n]
|
| 436 |
+
else:
|
| 437 |
+
top_idx = list(range(min(n, len(feat_names))))
|
| 438 |
+
|
| 439 |
+
result = []
|
| 440 |
+
for i in top_idx:
|
| 441 |
+
if i < len(feat_names) and i < len(features):
|
| 442 |
+
result.append({
|
| 443 |
+
"name": feat_names[i],
|
| 444 |
+
"value": round(float(features[i]), 4),
|
| 445 |
+
"importance": round(float(self._feature_importances[i]), 4)
|
| 446 |
+
if self._feature_importances is not None else 0.0,
|
| 447 |
+
})
|
| 448 |
+
return result
|
| 449 |
+
|
| 450 |
+
def _generate_explanation(
|
| 451 |
+
self,
|
| 452 |
+
features: np.ndarray,
|
| 453 |
+
heuristic: str,
|
| 454 |
+
reason: str,
|
| 455 |
+
probas: Dict[str, float],
|
| 456 |
+
) -> str:
|
| 457 |
+
"""Generate a plain-English explanation for THIS specific decision.
|
| 458 |
+
|
| 459 |
+
Rather than citing the globally most-important feature (which would
|
| 460 |
+
be identical across every decision), we pick the feature whose
|
| 461 |
+
per-decision contribution is highest. Contribution is approximated as
|
| 462 |
+
importance × |z-score of current value against training range|.
|
| 463 |
+
"""
|
| 464 |
+
from src.features import SCENARIO_FEATURE_NAMES
|
| 465 |
+
|
| 466 |
+
feat_names = self._feature_names or list(SCENARIO_FEATURE_NAMES)
|
| 467 |
+
feat_dict = dict(zip(feat_names, features.tolist()))
|
| 468 |
+
label = self.HEURISTIC_LABELS.get(heuristic, heuristic)
|
| 469 |
+
confidence = probas.get(heuristic, 0.0)
|
| 470 |
+
|
| 471 |
+
# Try to find a per-decision salient feature that has an explanation
|
| 472 |
+
# template for this heuristic.
|
| 473 |
+
if self._feature_importances is not None and len(feat_names) > 0:
|
| 474 |
+
ranges = getattr(self._fe, "_feature_ranges", None) or {}
|
| 475 |
+
# Compute a salience score per feature: importance × normalized deviation
|
| 476 |
+
salience = np.zeros(len(feat_names), dtype=float)
|
| 477 |
+
for i, name in enumerate(feat_names):
|
| 478 |
+
if i >= len(features) or i >= len(self._feature_importances):
|
| 479 |
+
continue
|
| 480 |
+
val = float(features[i])
|
| 481 |
+
imp = float(self._feature_importances[i])
|
| 482 |
+
lo_hi = ranges.get(name)
|
| 483 |
+
if lo_hi and lo_hi[1] > lo_hi[0]:
|
| 484 |
+
mid = 0.5 * (lo_hi[0] + lo_hi[1])
|
| 485 |
+
half = 0.5 * (lo_hi[1] - lo_hi[0])
|
| 486 |
+
deviation = abs(val - mid) / max(half, 1e-6)
|
| 487 |
+
else:
|
| 488 |
+
deviation = 1.0 # no range info -> fall back to importance only
|
| 489 |
+
salience[i] = imp * (0.5 + deviation) # floor keeps importance relevant
|
| 490 |
+
|
| 491 |
+
# Prefer features that have a template for this heuristic
|
| 492 |
+
ranked = np.argsort(salience)[::-1]
|
| 493 |
+
for idx in ranked[:8]: # look at top 8 salient features
|
| 494 |
+
if idx >= len(feat_names):
|
| 495 |
+
continue
|
| 496 |
+
fname = feat_names[idx]
|
| 497 |
+
key = (heuristic, fname)
|
| 498 |
+
if key in self._EXPLANATION_MAP:
|
| 499 |
+
reason_str = self._EXPLANATION_MAP[key]
|
| 500 |
+
val = feat_dict.get(fname, 0.0)
|
| 501 |
+
return (
|
| 502 |
+
f"DAHS selected {label} ({confidence:.0%} confidence) because "
|
| 503 |
+
f"{reason_str} ({fname}={val:.2f})."
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# No template hit — name the most salient feature generically
|
| 507 |
+
if ranked.size > 0:
|
| 508 |
+
idx0 = int(ranked[0])
|
| 509 |
+
if idx0 < len(feat_names):
|
| 510 |
+
fname = feat_names[idx0]
|
| 511 |
+
val = feat_dict.get(fname, 0.0)
|
| 512 |
+
return (
|
| 513 |
+
f"DAHS selected {label} with {confidence:.0%} confidence; "
|
| 514 |
+
f"the strongest driver for this decision was "
|
| 515 |
+
f"{fname}={val:.2f}."
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Generic fallback
|
| 519 |
+
return (
|
| 520 |
+
f"DAHS selected {label} with {confidence:.0%} confidence based on "
|
| 521 |
+
f"current system state. This is the predicted optimal heuristic for "
|
| 522 |
+
f"minimizing weighted tardiness and SLA breaches."
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
# ---------------------------------------------------------------------------
|
| 527 |
+
# HybridPriority (ported from DAHS_1)
|
| 528 |
+
# ---------------------------------------------------------------------------
|
| 529 |
+
|
| 530 |
+
class HybridPriority:
|
| 531 |
+
"""Wraps a trained GBR priority-predictor regressor."""
|
| 532 |
+
|
| 533 |
+
def __init__(
|
| 534 |
+
self,
|
| 535 |
+
model_path: Union[Path, str],
|
| 536 |
+
feature_extractor: Any,
|
| 537 |
+
) -> None:
|
| 538 |
+
self.model_path = Path(model_path)
|
| 539 |
+
self.feature_extractor = feature_extractor
|
| 540 |
+
self._model = joblib.load(self.model_path)
|
| 541 |
+
self._sim_state: Optional[Dict[str, Any]] = None
|
| 542 |
+
logger.info("HybridPriority loaded model from %s", self.model_path)
|
| 543 |
+
|
| 544 |
+
def update_state(self, sim_state: Dict[str, Any]) -> None:
|
| 545 |
+
self._sim_state = sim_state
|
| 546 |
+
|
| 547 |
+
def __call__(
|
| 548 |
+
self,
|
| 549 |
+
jobs: List[Any],
|
| 550 |
+
current_time: float,
|
| 551 |
+
zone_id: int,
|
| 552 |
+
) -> List[Any]:
|
| 553 |
+
"""Dispatch jobs by predicted priority score (descending)."""
|
| 554 |
+
from src.heuristics import fifo_dispatch
|
| 555 |
+
|
| 556 |
+
if not jobs:
|
| 557 |
+
return jobs
|
| 558 |
+
|
| 559 |
+
if self._sim_state is None:
|
| 560 |
+
return fifo_dispatch(jobs, current_time, zone_id)
|
| 561 |
+
|
| 562 |
+
try:
|
| 563 |
+
sf = self.feature_extractor.extract_scenario_features(self._sim_state)
|
| 564 |
+
job_feats = np.stack([
|
| 565 |
+
np.concatenate([sf, self.feature_extractor.extract_job_features(j, self._sim_state)])
|
| 566 |
+
for j in jobs
|
| 567 |
+
])
|
| 568 |
+
predictions = self._model.predict(job_feats)
|
| 569 |
+
ranked = sorted(zip(predictions, jobs), key=lambda x: x[0], reverse=True)
|
| 570 |
+
return [job for _, job in ranked]
|
| 571 |
+
except Exception as exc:
|
| 572 |
+
from src.heuristics import fifo_dispatch
|
| 573 |
+
logger.warning("HybridPriority error: %s — falling back to FIFO", exc)
|
| 574 |
+
return fifo_dispatch(jobs, current_time, zone_id)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
# ---------------------------------------------------------------------------
|
| 578 |
+
# Rolling-Horizon Fork Oracle (DAHS 2.1) — hard performance guarantee
|
| 579 |
+
# ---------------------------------------------------------------------------
|
| 580 |
+
|
| 581 |
+
class RollingHorizonOracle:
|
| 582 |
+
"""Pure fork-oracle selector with a mathematical per-window guarantee.
|
| 583 |
+
|
| 584 |
+
At each EVAL_INTERVAL minutes it clones the simulator via save_state,
|
| 585 |
+
runs every heuristic forward for HORIZON minutes using the preserved RNG
|
| 586 |
+
(so all forks see identical future arrivals), then picks the argmin of
|
| 587 |
+
a composite cost matching the benchmark objective. Because forks are
|
| 588 |
+
RNG-deterministic, the argmin per window is an exact oracle; summed
|
| 589 |
+
over the day, cumulative cost is mathematically ≤ min-over-heuristics.
|
| 590 |
+
|
| 591 |
+
Compute cost: 6 forks × HORIZON min × (600 / EVAL_INTERVAL) decisions ≈
|
| 592 |
+
21,600 sim-min/day for H=90 — a constant multiplier on the base sim time.
|
| 593 |
+
|
| 594 |
+
Usage:
|
| 595 |
+
sim = WarehouseSimulator(seed=..., heuristic_fn=lambda j, t, z: j, ...)
|
| 596 |
+
oracle = RollingHorizonOracle()
|
| 597 |
+
oracle.attach_simulator(sim)
|
| 598 |
+
sim.heuristic_fn = lambda jobs, t, z: oracle.dispatch(jobs, t, z)
|
| 599 |
+
sim.run(duration=600.0)
|
| 600 |
+
"""
|
| 601 |
+
|
| 602 |
+
EVAL_INTERVAL = 15.0
|
| 603 |
+
HORIZON = 90.0 # ≥ median job cycle (23 min Olist) × 4 — eliminates myopia
|
| 604 |
+
STARVATION_LIMIT = 60.0
|
| 605 |
+
HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]
|
| 606 |
+
|
| 607 |
+
# Cost weights aligned with benchmark objective (tardiness-dominant)
|
| 608 |
+
W_TARD = 0.55
|
| 609 |
+
W_SLA = 0.35
|
| 610 |
+
W_CYC = 0.10
|
| 611 |
+
|
| 612 |
+
def __init__(self, ml_model: Optional[Any] = None, feature_extractor: Any = None) -> None:
|
| 613 |
+
"""Pure oracle when ml_model is None; hybrid (ML prior) when supplied."""
|
| 614 |
+
self._ml_model = ml_model
|
| 615 |
+
self._fe = feature_extractor
|
| 616 |
+
self._sim: Optional[Any] = None
|
| 617 |
+
self._current_heuristic: str = "fifo"
|
| 618 |
+
self._last_eval_time: float = -999.0
|
| 619 |
+
self._last_breakdown_count: int = 0
|
| 620 |
+
self._last_lunch_state: bool = False
|
| 621 |
+
self.switching_log = SwitchingLog()
|
| 622 |
+
|
| 623 |
+
def attach_simulator(self, sim: Any) -> None:
|
| 624 |
+
"""Bind to the main simulator so we can snapshot it for forks."""
|
| 625 |
+
self._sim = sim
|
| 626 |
+
|
| 627 |
+
def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 628 |
+
return self.dispatch(jobs, current_time, zone_id)
|
| 629 |
+
|
| 630 |
+
def dispatch(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
|
| 631 |
+
from src.heuristics import DISPATCH_MAP, fifo_dispatch
|
| 632 |
+
|
| 633 |
+
if not jobs:
|
| 634 |
+
return jobs
|
| 635 |
+
|
| 636 |
+
# Re-evaluate every EVAL_INTERVAL minutes or on state-changing events
|
| 637 |
+
if self._sim is not None and self._should_reevaluate(current_time):
|
| 638 |
+
self._reevaluate(current_time)
|
| 639 |
+
|
| 640 |
+
fn = DISPATCH_MAP.get(self._current_heuristic, fifo_dispatch)
|
| 641 |
+
ordered = fn(jobs, current_time, zone_id)
|
| 642 |
+
ordered = self._apply_starvation_prevention(ordered, current_time)
|
| 643 |
+
return ordered
|
| 644 |
+
|
| 645 |
+
# ------------------------------------------------------------------
|
| 646 |
+
# Fork-oracle evaluation
|
| 647 |
+
# ------------------------------------------------------------------
|
| 648 |
+
|
| 649 |
+
def _should_reevaluate(self, now: float) -> bool:
|
| 650 |
+
if self._sim is None:
|
| 651 |
+
return False
|
| 652 |
+
if now - self._last_eval_time >= self.EVAL_INTERVAL:
|
| 653 |
+
return True
|
| 654 |
+
# disruption events
|
| 655 |
+
n_broken = sum(
|
| 656 |
+
1 for st in getattr(self._sim, "stations", {}).values()
|
| 657 |
+
if getattr(st, "is_broken", False)
|
| 658 |
+
)
|
| 659 |
+
if n_broken != self._last_breakdown_count:
|
| 660 |
+
return True
|
| 661 |
+
lunch = getattr(self._sim, "_lunch_active", False)
|
| 662 |
+
if lunch != self._last_lunch_state:
|
| 663 |
+
return True
|
| 664 |
+
return False
|
| 665 |
+
|
| 666 |
+
def _reevaluate(self, now: float) -> None:
|
| 667 |
+
"""Fork all heuristics, score, select best. Hard guarantee lives here."""
|
| 668 |
+
from src.heuristics import DISPATCH_MAP
|
| 669 |
+
from src.simulator import WarehouseSimulator
|
| 670 |
+
|
| 671 |
+
self._last_eval_time = now
|
| 672 |
+
self._last_breakdown_count = sum(
|
| 673 |
+
1 for st in getattr(self._sim, "stations", {}).values()
|
| 674 |
+
if getattr(st, "is_broken", False)
|
| 675 |
+
)
|
| 676 |
+
self._last_lunch_state = getattr(self._sim, "_lunch_active", False)
|
| 677 |
+
|
| 678 |
+
try:
|
| 679 |
+
saved = self._sim.save_state()
|
| 680 |
+
except Exception as e:
|
| 681 |
+
logger.warning("Oracle save_state failed: %s", e)
|
| 682 |
+
return
|
| 683 |
+
|
| 684 |
+
fork_end = now + self.HORIZON
|
| 685 |
+
scores: Dict[str, float] = {}
|
| 686 |
+
raw: Dict[str, Tuple[float, float, float]] = {}
|
| 687 |
+
|
| 688 |
+
for heur in self.HEURISTIC_NAMES:
|
| 689 |
+
try:
|
| 690 |
+
heur_fn = DISPATCH_MAP[heur]
|
| 691 |
+
fork = WarehouseSimulator.from_state(saved, heur_fn)
|
| 692 |
+
fork.step_to(fork_end)
|
| 693 |
+
m = fork.get_partial_metrics(since_time=now)
|
| 694 |
+
tard = float(m.total_tardiness) if np.isfinite(m.total_tardiness) else 1e9
|
| 695 |
+
sla = float(m.sla_breach_rate) if np.isfinite(m.sla_breach_rate) else 1.0
|
| 696 |
+
cyc = float(m.avg_cycle_time) if np.isfinite(m.avg_cycle_time) else 1e6
|
| 697 |
+
except Exception as e:
|
| 698 |
+
logger.warning("Fork for %s failed at t=%.1f: %s", heur, now, e)
|
| 699 |
+
tard, sla, cyc = 1e9, 1.0, 1e6
|
| 700 |
+
raw[heur] = (tard, sla, cyc)
|
| 701 |
+
|
| 702 |
+
# Normalize across heuristics so units are comparable, then composite score
|
| 703 |
+
tards = np.array([raw[h][0] for h in self.HEURISTIC_NAMES])
|
| 704 |
+
slas = np.array([raw[h][1] for h in self.HEURISTIC_NAMES])
|
| 705 |
+
cycs = np.array([raw[h][2] for h in self.HEURISTIC_NAMES])
|
| 706 |
+
|
| 707 |
+
def _norm(a: np.ndarray) -> np.ndarray:
|
| 708 |
+
lo, hi = float(a.min()), float(a.max())
|
| 709 |
+
if hi - lo < 1e-10:
|
| 710 |
+
return np.zeros_like(a)
|
| 711 |
+
return (a - lo) / (hi - lo)
|
| 712 |
+
|
| 713 |
+
n_t = _norm(tards); n_s = _norm(slas); n_c = _norm(cycs)
|
| 714 |
+
composite = self.W_TARD * n_t + self.W_SLA * n_s + self.W_CYC * n_c
|
| 715 |
+
for i, h in enumerate(self.HEURISTIC_NAMES):
|
| 716 |
+
scores[h] = float(composite[i])
|
| 717 |
+
|
| 718 |
+
# Optional ML prior for tie-breaking (Hybrid mode). Does NOT override
|
| 719 |
+
# oracle-chosen winner; only nudges among near-ties.
|
| 720 |
+
ml_probs: Dict[str, float] = {}
|
| 721 |
+
if self._ml_model is not None and self._fe is not None:
|
| 722 |
+
try:
|
| 723 |
+
sim_state = self._sim.get_state_snapshot()
|
| 724 |
+
feats = self._fe.extract_scenario_features(sim_state)
|
| 725 |
+
probs = self._ml_model.predict_proba(feats.reshape(1, -1))[0]
|
| 726 |
+
for i, h in enumerate(self.HEURISTIC_NAMES):
|
| 727 |
+
if i < len(probs):
|
| 728 |
+
ml_probs[h] = float(probs[i])
|
| 729 |
+
except Exception as e:
|
| 730 |
+
logger.debug("ML prior failed (non-fatal): %s", e)
|
| 731 |
+
|
| 732 |
+
# Pick best oracle score; break ties (within 2%) by highest ML probability
|
| 733 |
+
sorted_h = sorted(self.HEURISTIC_NAMES, key=lambda h: scores[h])
|
| 734 |
+
best = sorted_h[0]
|
| 735 |
+
best_score = scores[best]
|
| 736 |
+
if ml_probs:
|
| 737 |
+
tied = [h for h in sorted_h if scores[h] - best_score < 0.02]
|
| 738 |
+
if len(tied) > 1:
|
| 739 |
+
best = max(tied, key=lambda h: ml_probs.get(h, 0.0))
|
| 740 |
+
|
| 741 |
+
switched = best != self._current_heuristic
|
| 742 |
+
self.switching_log.record(
|
| 743 |
+
time=now,
|
| 744 |
+
features=[float(raw[h][0]) for h in self.HEURISTIC_NAMES],
|
| 745 |
+
probabilities={h: round(scores[h], 4) for h in self.HEURISTIC_NAMES},
|
| 746 |
+
selected=best,
|
| 747 |
+
switched=switched,
|
| 748 |
+
reason="oracle_fork" if not ml_probs else "hybrid_oracle",
|
| 749 |
+
confidence=1.0 - best_score, # lower composite → higher confidence
|
| 750 |
+
top_features=[
|
| 751 |
+
{"name": f"oracle_tard_{h}", "value": round(raw[h][0], 2), "importance": 1.0}
|
| 752 |
+
for h in self.HEURISTIC_NAMES
|
| 753 |
+
],
|
| 754 |
+
plain_english=(
|
| 755 |
+
f"Oracle fork: {best} wins next {int(self.HORIZON)}-min horizon "
|
| 756 |
+
f"(composite score {best_score:.3f})."
|
| 757 |
+
),
|
| 758 |
+
)
|
| 759 |
+
self._current_heuristic = best
|
| 760 |
+
|
| 761 |
+
def _apply_starvation_prevention(self, jobs: List[Any], current_time: float) -> List[Any]:
|
| 762 |
+
starving = [j for j in jobs if (current_time - j.arrival_time) > self.STARVATION_LIMIT]
|
| 763 |
+
non_starving = [j for j in jobs if j not in starving]
|
| 764 |
+
return starving + non_starving
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
# ---------------------------------------------------------------------------
|
| 768 |
+
# Factory helpers
|
| 769 |
+
# ---------------------------------------------------------------------------
|
| 770 |
+
|
| 771 |
+
def load_batchwise_selector(
|
| 772 |
+
model_name: str = "rf",
|
| 773 |
+
feature_extractor: Any = None,
|
| 774 |
+
) -> BatchwiseSelector:
|
| 775 |
+
"""Load a BatchwiseSelector for a given classifier variant.
|
| 776 |
+
|
| 777 |
+
Parameters
|
| 778 |
+
----------
|
| 779 |
+
model_name : str
|
| 780 |
+
One of "dt", "rf", "xgb".
|
| 781 |
+
feature_extractor : FeatureExtractor
|
| 782 |
+
Feature extraction instance.
|
| 783 |
+
"""
|
| 784 |
+
import json
|
| 785 |
+
|
| 786 |
+
if feature_extractor is None:
|
| 787 |
+
from src.features import FeatureExtractor
|
| 788 |
+
feature_extractor = FeatureExtractor()
|
| 789 |
+
|
| 790 |
+
path = MODELS_DIR / f"selector_{model_name}.joblib"
|
| 791 |
+
if not path.exists():
|
| 792 |
+
raise FileNotFoundError(f"Model not found: {path}")
|
| 793 |
+
model = joblib.load(path)
|
| 794 |
+
|
| 795 |
+
model_hash = getattr(model, "_dahs_run_hash", None)
|
| 796 |
+
|
| 797 |
+
# Load feature importances if available
|
| 798 |
+
feature_importances = None
|
| 799 |
+
feature_names = None
|
| 800 |
+
names_meta: Dict[str, Any] = {}
|
| 801 |
+
|
| 802 |
+
try:
|
| 803 |
+
feature_names_path = MODELS_DIR / "feature_names.json"
|
| 804 |
+
if feature_names_path.exists():
|
| 805 |
+
with open(feature_names_path) as f:
|
| 806 |
+
names_data = json.load(f)
|
| 807 |
+
if isinstance(names_data, dict) and "features" in names_data:
|
| 808 |
+
names_meta = names_data.get("_meta", {})
|
| 809 |
+
feature_names = [d["name"] for d in names_data["features"]]
|
| 810 |
+
else:
|
| 811 |
+
feature_names = [d["name"] for d in names_data]
|
| 812 |
+
|
| 813 |
+
if hasattr(model, "feature_importances_"):
|
| 814 |
+
feature_importances = model.feature_importances_
|
| 815 |
+
except Exception as exc:
|
| 816 |
+
logger.warning("Failed to load feature_names.json: %s", exc)
|
| 817 |
+
|
| 818 |
+
# Load feature ranges for OOD detection
|
| 819 |
+
ranges_meta: Dict[str, Any] = {}
|
| 820 |
+
try:
|
| 821 |
+
ranges_path = MODELS_DIR / "feature_ranges.json"
|
| 822 |
+
if ranges_path.exists():
|
| 823 |
+
feature_extractor.load_feature_ranges(ranges_path)
|
| 824 |
+
ranges_meta = getattr(feature_extractor, "_feature_ranges_meta", {}) or {}
|
| 825 |
+
except Exception as exc:
|
| 826 |
+
logger.warning("Failed to load feature_ranges.json: %s", exc)
|
| 827 |
+
|
| 828 |
+
# Validate that all artifacts came from the same training run. Legacy
|
| 829 |
+
# artifacts (model_hash is None) are tolerated for backwards compatibility,
|
| 830 |
+
# but any present-and-disagreeing hashes raise loudly — a mismatch means
|
| 831 |
+
# someone retrained without regenerating sidecars and the OOD guardrail
|
| 832 |
+
# would otherwise apply stale ranges.
|
| 833 |
+
artifact_hashes = {
|
| 834 |
+
"model": model_hash,
|
| 835 |
+
"feature_ranges": ranges_meta.get("run_hash"),
|
| 836 |
+
"feature_names": names_meta.get("run_hash"),
|
| 837 |
+
}
|
| 838 |
+
present = {k: v for k, v in artifact_hashes.items() if v is not None}
|
| 839 |
+
if len(set(present.values())) > 1:
|
| 840 |
+
raise RuntimeError(
|
| 841 |
+
"DAHS model/artifact hash mismatch — re-run scripts/run_pipeline.py "
|
| 842 |
+
f"to regenerate them in lockstep. Hashes: {artifact_hashes}"
|
| 843 |
+
)
|
| 844 |
+
if feature_names is not None and hasattr(model, "n_features_in_"):
|
| 845 |
+
if model.n_features_in_ != len(feature_names):
|
| 846 |
+
raise RuntimeError(
|
| 847 |
+
f"Model expects {model.n_features_in_} features but "
|
| 848 |
+
f"feature_names.json has {len(feature_names)}. Retrain."
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
return BatchwiseSelector(
|
| 852 |
+
model=model,
|
| 853 |
+
feature_extractor=feature_extractor,
|
| 854 |
+
feature_importances=feature_importances,
|
| 855 |
+
feature_names=feature_names,
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def load_hybrid_priority(feature_extractor: Any = None) -> HybridPriority:
|
| 860 |
+
"""Load the GBR-based HybridPriority scheduler."""
|
| 861 |
+
if feature_extractor is None:
|
| 862 |
+
from src.features import FeatureExtractor
|
| 863 |
+
feature_extractor = FeatureExtractor()
|
| 864 |
+
path = MODELS_DIR / "priority_gbr.joblib"
|
| 865 |
+
return HybridPriority(model_path=path, feature_extractor=feature_extractor)
|
src/presets.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
presets.py — Static-Solver Comparison Presets for DAHS_2
|
| 3 |
+
|
| 4 |
+
Each preset pins a single classical dispatch rule (FIFO, Priority-EDD, …) that
|
| 5 |
+
runs for the full 600-minute shift. The stress environment is the same realistic,
|
| 6 |
+
literature-calibrated workload used everywhere else in the project:
|
| 7 |
+
|
| 8 |
+
- Time-varying job-type composition (morning Type-A dominant → afternoon bulk
|
| 9 |
+
B/C/D → evening Type-E express surge), simulator._COMPOSITION_PROFILE.
|
| 10 |
+
- Bimodal intraday arrival-rate curve with a lunch dip and an evening peak,
|
| 11 |
+
simulator._SURGE_PROFILE.
|
| 12 |
+
- Per-type processing-time lognormal variability (CV ≈ 30 %) and Poisson
|
| 13 |
+
arrivals, all stochastic.
|
| 14 |
+
|
| 15 |
+
Presets intentionally do **not** override job_type_frequencies: the workload is
|
| 16 |
+
identical across presets and DAHS, so the only experimental variable is the
|
| 17 |
+
dispatch strategy itself. This rules out composition bias as an explanation for
|
| 18 |
+
any performance gap and makes the static-solver-vs-DAHS comparison a clean
|
| 19 |
+
controlled experiment.
|
| 20 |
+
|
| 21 |
+
Presets differ in operational stress parameters (arrival rate, breakdown rate,
|
| 22 |
+
batch size, deadline tightness, processing-time scale) so the static-solver
|
| 23 |
+
comparison is tested across a range of realistic operating regimes.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import logging
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
HEURISTIC_INDEX = {
|
| 35 |
+
"fifo": 0,
|
| 36 |
+
"priority_edd": 1,
|
| 37 |
+
"critical_ratio": 2,
|
| 38 |
+
"atc": 3,
|
| 39 |
+
"wspt": 4,
|
| 40 |
+
"slack": 5,
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
HEURISTIC_LABELS = ["FIFO", "Priority-EDD", "Critical-Ratio", "ATC", "WSPT", "Slack"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class PresetScenario:
|
| 48 |
+
"""A 600-min single-solver scenario used as a static baseline against DAHS.
|
| 49 |
+
|
| 50 |
+
The solver named by ``favored_heuristic`` runs for the entire shift. The
|
| 51 |
+
workload composition is always the realistic time-varying profile embedded
|
| 52 |
+
in the simulator — this preset only configures stress parameters
|
| 53 |
+
(arrival rate, breakdowns, deadline tightness, etc.).
|
| 54 |
+
"""
|
| 55 |
+
name: str
|
| 56 |
+
description: str
|
| 57 |
+
favored_heuristic: str
|
| 58 |
+
favored_heuristic_idx: int
|
| 59 |
+
seed: int
|
| 60 |
+
|
| 61 |
+
base_arrival_rate: float = 2.5
|
| 62 |
+
breakdown_prob: float = 0.003
|
| 63 |
+
batch_arrival_size: int = 30
|
| 64 |
+
lunch_penalty_factor: float = 1.3
|
| 65 |
+
|
| 66 |
+
# Kept for API compatibility. Presets leave this empty so the simulator
|
| 67 |
+
# falls through to its realistic time-varying _COMPOSITION_PROFILE.
|
| 68 |
+
# Setting a non-empty dict here would override the profile and reintroduce
|
| 69 |
+
# composition bias — intentionally avoided.
|
| 70 |
+
job_type_frequencies: Dict[str, float] = field(default_factory=dict)
|
| 71 |
+
due_date_tightness: float = 1.0
|
| 72 |
+
processing_time_scale: float = 1.0
|
| 73 |
+
why_it_favors: str = ""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
PRESETS: List[PresetScenario] = [
|
| 77 |
+
|
| 78 |
+
# ── Preset 1: FIFO — light, low-disruption baseline ─────────────────────
|
| 79 |
+
PresetScenario(
|
| 80 |
+
name="Preset-1-FIFO",
|
| 81 |
+
description="Light steady flow, no breakdowns, generous deadlines — FIFO runs for the full 600 min",
|
| 82 |
+
favored_heuristic="fifo",
|
| 83 |
+
favored_heuristic_idx=0,
|
| 84 |
+
seed=200_001,
|
| 85 |
+
base_arrival_rate=2.0,
|
| 86 |
+
breakdown_prob=0.0,
|
| 87 |
+
batch_arrival_size=10,
|
| 88 |
+
lunch_penalty_factor=1.0,
|
| 89 |
+
due_date_tightness=2.5,
|
| 90 |
+
processing_time_scale=1.0,
|
| 91 |
+
why_it_favors=(
|
| 92 |
+
"Light load with loose deadlines and no disruptions — a regime where "
|
| 93 |
+
"FIFO's simplicity is hard to beat. Runs on the same realistic "
|
| 94 |
+
"time-varying package mix (A-dominant morning → B/C/D bulk afternoon → "
|
| 95 |
+
"Type-E express evening) as every other arm."
|
| 96 |
+
),
|
| 97 |
+
),
|
| 98 |
+
|
| 99 |
+
# ── Preset 2: Priority-EDD — tight deadlines, frequent express orders ──
|
| 100 |
+
PresetScenario(
|
| 101 |
+
name="Preset-2-Priority-EDD",
|
| 102 |
+
description="Tight deadlines with frequent express orders — Priority-EDD runs for the full 600 min",
|
| 103 |
+
favored_heuristic="priority_edd",
|
| 104 |
+
favored_heuristic_idx=1,
|
| 105 |
+
seed=200_002,
|
| 106 |
+
base_arrival_rate=2.5,
|
| 107 |
+
breakdown_prob=0.001,
|
| 108 |
+
batch_arrival_size=20,
|
| 109 |
+
lunch_penalty_factor=1.1,
|
| 110 |
+
due_date_tightness=0.65,
|
| 111 |
+
processing_time_scale=1.0,
|
| 112 |
+
why_it_favors=(
|
| 113 |
+
"Tight deadlines give Priority-EDD a natural edge: sorting by "
|
| 114 |
+
"(priority class, due date) captures urgency directly. Workload is "
|
| 115 |
+
"the same realistic A→E daily profile — any advantage comes from "
|
| 116 |
+
"the dispatch rule, not from a biased job mix."
|
| 117 |
+
),
|
| 118 |
+
),
|
| 119 |
+
|
| 120 |
+
# ── Preset 3: Critical Ratio — frequent station breakdowns ─────────────
|
| 121 |
+
PresetScenario(
|
| 122 |
+
name="Preset-3-CR",
|
| 123 |
+
description="Frequent station breakdowns on a realistic workload — Critical-Ratio runs for the full 600 min",
|
| 124 |
+
favored_heuristic="critical_ratio",
|
| 125 |
+
favored_heuristic_idx=2,
|
| 126 |
+
seed=200_003,
|
| 127 |
+
base_arrival_rate=2.5,
|
| 128 |
+
breakdown_prob=0.018,
|
| 129 |
+
batch_arrival_size=20,
|
| 130 |
+
lunch_penalty_factor=1.2,
|
| 131 |
+
due_date_tightness=0.85,
|
| 132 |
+
processing_time_scale=1.0,
|
| 133 |
+
why_it_favors=(
|
| 134 |
+
"Frequent breakdowns make static urgency scores go stale. "
|
| 135 |
+
"Critical-Ratio = (due_date − now) / remaining_proc_time is "
|
| 136 |
+
"recomputed every dispatch, so it tracks live time pressure. "
|
| 137 |
+
"The arrival stream is the realistic time-varying one."
|
| 138 |
+
),
|
| 139 |
+
),
|
| 140 |
+
|
| 141 |
+
# ── Preset 4: ATC — heavy load, morning surge ──────────────────────────
|
| 142 |
+
PresetScenario(
|
| 143 |
+
name="Preset-4-ATC",
|
| 144 |
+
description="Heavy sustained load with high-weight jobs — ATC runs for the full 600 min",
|
| 145 |
+
favored_heuristic="atc",
|
| 146 |
+
favored_heuristic_idx=3,
|
| 147 |
+
seed=200_004,
|
| 148 |
+
base_arrival_rate=4.0,
|
| 149 |
+
breakdown_prob=0.003,
|
| 150 |
+
batch_arrival_size=50,
|
| 151 |
+
lunch_penalty_factor=1.4,
|
| 152 |
+
due_date_tightness=0.55,
|
| 153 |
+
processing_time_scale=1.0,
|
| 154 |
+
why_it_favors=(
|
| 155 |
+
"Sustained heavy load needs joint weight–urgency optimisation. "
|
| 156 |
+
"ATC's (w/p)·exp(−slack/K·p̄) closed form is near-optimal for "
|
| 157 |
+
"weighted tardiness under congestion. Workload composition follows "
|
| 158 |
+
"the realistic daily profile — no preset-specific mix."
|
| 159 |
+
),
|
| 160 |
+
),
|
| 161 |
+
|
| 162 |
+
# ── Preset 5: WSPT — short jobs, loose deadlines, throughput focus ─────
|
| 163 |
+
PresetScenario(
|
| 164 |
+
name="Preset-5-WSPT",
|
| 165 |
+
description="Short-jobs-dominate regime with loose deadlines — WSPT runs for the full 600 min",
|
| 166 |
+
favored_heuristic="wspt",
|
| 167 |
+
favored_heuristic_idx=4,
|
| 168 |
+
seed=200_005,
|
| 169 |
+
base_arrival_rate=3.0,
|
| 170 |
+
breakdown_prob=0.001,
|
| 171 |
+
batch_arrival_size=15,
|
| 172 |
+
lunch_penalty_factor=1.0,
|
| 173 |
+
due_date_tightness=2.0,
|
| 174 |
+
processing_time_scale=0.7,
|
| 175 |
+
why_it_favors=(
|
| 176 |
+
"Processing times scaled down 30 % give short jobs on loose deadlines "
|
| 177 |
+
"— the regime where Smith's weighted-shortest-processing-time rule "
|
| 178 |
+
"is provably optimal for minimising weighted flow time. The arrival "
|
| 179 |
+
"composition is the realistic time-varying profile."
|
| 180 |
+
),
|
| 181 |
+
),
|
| 182 |
+
|
| 183 |
+
# ── Preset 6: Slack — recovery mode, very tight deadlines ──────────────
|
| 184 |
+
PresetScenario(
|
| 185 |
+
name="Preset-6-Slack",
|
| 186 |
+
description="Recovery mode with very tight deadlines — Slack runs for the full 600 min",
|
| 187 |
+
favored_heuristic="slack",
|
| 188 |
+
favored_heuristic_idx=5,
|
| 189 |
+
seed=200_006,
|
| 190 |
+
base_arrival_rate=3.5,
|
| 191 |
+
breakdown_prob=0.002,
|
| 192 |
+
batch_arrival_size=60,
|
| 193 |
+
lunch_penalty_factor=1.2,
|
| 194 |
+
due_date_tightness=0.30,
|
| 195 |
+
processing_time_scale=1.2,
|
| 196 |
+
why_it_favors=(
|
| 197 |
+
"Extreme deadline tightness triggers recovery behaviour. Slack "
|
| 198 |
+
"= due_date − now − remaining_proc_time identifies which jobs can "
|
| 199 |
+
"still be saved versus which are already lost. Workload is the "
|
| 200 |
+
"realistic daily profile; stress comes from deadlines and batch size."
|
| 201 |
+
),
|
| 202 |
+
),
|
| 203 |
+
|
| 204 |
+
# ── Preset 7: Real-Data Calibrated (Olist) — stress params only ────────
|
| 205 |
+
PresetScenario(
|
| 206 |
+
name="Preset-7-RealData",
|
| 207 |
+
description=(
|
| 208 |
+
"Stress parameters calibrated from Olist Brazilian E-Commerce "
|
| 209 |
+
"dataset (96,478 real orders, 2016-2018) — WSPT runs for the full 600 min"
|
| 210 |
+
),
|
| 211 |
+
favored_heuristic="wspt",
|
| 212 |
+
favored_heuristic_idx=4,
|
| 213 |
+
seed=200_007,
|
| 214 |
+
# arrival_rate: Olist implies ~9.9 orders/hr; we use 30/hr (0.5/min)
|
| 215 |
+
# representing a mid-scale DC operating at ~20% of peak capacity.
|
| 216 |
+
# Ref: Olist Brazilian E-Commerce Dataset, Kaggle (2018);
|
| 217 |
+
# Published DC range 60-150/hr — Gu et al. (2010) EJOR 203(3):539-549.
|
| 218 |
+
base_arrival_rate=0.5,
|
| 219 |
+
# breakdown_prob: empirical 2-5% of operational hours — Inman (1999)
|
| 220 |
+
breakdown_prob=0.003,
|
| 221 |
+
# batch_arrival_size: calibrated to Olist avg items/order (~1.2 items)
|
| 222 |
+
# scaled to warehouse batch size range — Bartholdi & Hackman (2019)
|
| 223 |
+
batch_arrival_size=15,
|
| 224 |
+
lunch_penalty_factor=1.2,
|
| 225 |
+
# due_date_tightness: derived from Olist SLA/cycle ratio (23.2d / 10.2d = 2.27)
|
| 226 |
+
# mapped to simulator scale: 1.5x gives comparable SLA pressure
|
| 227 |
+
due_date_tightness=1.5,
|
| 228 |
+
processing_time_scale=1.0,
|
| 229 |
+
why_it_favors=(
|
| 230 |
+
"Operational parameters (arrival rate 30/hr, batch size 15, "
|
| 231 |
+
"deadline tightness 1.5×) are calibrated from 96,478 real Olist "
|
| 232 |
+
"orders. Package composition still follows the realistic "
|
| 233 |
+
"time-varying profile so there is no composition bias. WSPT is the "
|
| 234 |
+
"static baseline for this operating regime."
|
| 235 |
+
),
|
| 236 |
+
),
|
| 237 |
+
]
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def get_preset(name: str) -> PresetScenario:
|
| 241 |
+
"""Return a preset by name (case-insensitive match on prefix)."""
|
| 242 |
+
name_lower = name.lower()
|
| 243 |
+
for p in PRESETS:
|
| 244 |
+
if p.name.lower() == name_lower or p.favored_heuristic == name_lower:
|
| 245 |
+
return p
|
| 246 |
+
raise ValueError(
|
| 247 |
+
f"Unknown preset: {name!r}. Available: {[p.name for p in PRESETS]}"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def get_all_presets() -> List[PresetScenario]:
|
| 252 |
+
"""Return all preset scenario configs."""
|
| 253 |
+
return list(PRESETS)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def run_preset_demo(
|
| 257 |
+
preset: PresetScenario,
|
| 258 |
+
duration: float = 600.0,
|
| 259 |
+
) -> Dict[str, Any]:
|
| 260 |
+
"""Run all 6 baselines + DAHS on a preset, returning full comparison results."""
|
| 261 |
+
from src.heuristics import (
|
| 262 |
+
fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
|
| 263 |
+
atc_dispatch, wspt_dispatch, slack_dispatch,
|
| 264 |
+
)
|
| 265 |
+
from src.simulator import WarehouseSimulator
|
| 266 |
+
from src.features import FeatureExtractor
|
| 267 |
+
|
| 268 |
+
dispatch_map = {
|
| 269 |
+
"fifo": fifo_dispatch,
|
| 270 |
+
"priority_edd": priority_edd_dispatch,
|
| 271 |
+
"critical_ratio": critical_ratio_dispatch,
|
| 272 |
+
"atc": atc_dispatch,
|
| 273 |
+
"wspt": wspt_dispatch,
|
| 274 |
+
"slack": slack_dispatch,
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
sim_kwargs = {
|
| 278 |
+
"base_arrival_rate": preset.base_arrival_rate,
|
| 279 |
+
"breakdown_prob": preset.breakdown_prob,
|
| 280 |
+
"batch_arrival_size": preset.batch_arrival_size,
|
| 281 |
+
"lunch_penalty_factor": preset.lunch_penalty_factor,
|
| 282 |
+
"job_type_frequencies": preset.job_type_frequencies or {},
|
| 283 |
+
"due_date_tightness": preset.due_date_tightness,
|
| 284 |
+
"processing_time_scale": preset.processing_time_scale,
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
results: Dict[str, Any] = {}
|
| 288 |
+
|
| 289 |
+
for heur_name, heur_fn in dispatch_map.items():
|
| 290 |
+
fe = FeatureExtractor()
|
| 291 |
+
sim = WarehouseSimulator(seed=preset.seed, heuristic_fn=heur_fn, feature_extractor=fe, **sim_kwargs)
|
| 292 |
+
metrics = sim.run(duration=duration)
|
| 293 |
+
results[heur_name] = metrics
|
| 294 |
+
logger.info(
|
| 295 |
+
"[%s] %s: tardiness=%.1f, sla=%.3f, throughput=%.2f",
|
| 296 |
+
preset.name, heur_name, metrics.total_tardiness, metrics.sla_breach_rate, metrics.throughput,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
import numpy as np
|
| 300 |
+
tardy = np.array([results[h].total_tardiness for h in dispatch_map])
|
| 301 |
+
sla = np.array([results[h].sla_breach_rate for h in dispatch_map])
|
| 302 |
+
cyc = np.array([results[h].avg_cycle_time for h in dispatch_map])
|
| 303 |
+
|
| 304 |
+
def _norm(arr):
|
| 305 |
+
r = arr.max() - arr.min()
|
| 306 |
+
return np.zeros_like(arr) if r == 0 else (arr - arr.min()) / r
|
| 307 |
+
|
| 308 |
+
scores = 0.40 * _norm(tardy) + 0.35 * _norm(sla) + 0.25 * _norm(cyc)
|
| 309 |
+
best_idx = int(np.argmin(scores))
|
| 310 |
+
winner = list(dispatch_map.keys())[best_idx]
|
| 311 |
+
|
| 312 |
+
logger.info("[%s] Empirical winner: %s (expected: %s) — %s",
|
| 313 |
+
preset.name, winner, preset.favored_heuristic,
|
| 314 |
+
"CORRECT" if winner == preset.favored_heuristic else "UNEXPECTED")
|
| 315 |
+
|
| 316 |
+
# Try running DAHS if models are available
|
| 317 |
+
dahs_selected = None
|
| 318 |
+
switching_log = None
|
| 319 |
+
|
| 320 |
+
try:
|
| 321 |
+
from src.hybrid_scheduler import BatchwiseSelector, MODELS_DIR
|
| 322 |
+
from pathlib import Path as _Path
|
| 323 |
+
model_path = _Path(MODELS_DIR) / "selector_rf.joblib"
|
| 324 |
+
if model_path.exists():
|
| 325 |
+
import joblib
|
| 326 |
+
model = joblib.load(model_path)
|
| 327 |
+
fe = FeatureExtractor()
|
| 328 |
+
selector = BatchwiseSelector(model=model, feature_extractor=fe)
|
| 329 |
+
|
| 330 |
+
dahs_sim = WarehouseSimulator(
|
| 331 |
+
seed=preset.seed,
|
| 332 |
+
heuristic_fn=fifo_dispatch,
|
| 333 |
+
feature_extractor=fe,
|
| 334 |
+
**sim_kwargs,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
def dahs_dispatch(jobs, t, zone_id):
|
| 338 |
+
selector.update_state(dahs_sim.get_state_snapshot())
|
| 339 |
+
return selector.dispatch(jobs, t, zone_id)
|
| 340 |
+
|
| 341 |
+
dahs_sim.heuristic_fn = dahs_dispatch
|
| 342 |
+
dahs_metrics = dahs_sim.run(duration=duration)
|
| 343 |
+
results["dahs"] = dahs_metrics
|
| 344 |
+
switching_log = selector.switching_log
|
| 345 |
+
|
| 346 |
+
dist: Dict[str, int] = {}
|
| 347 |
+
for e in switching_log.entries:
|
| 348 |
+
h = e["selected"]
|
| 349 |
+
dist[h] = dist.get(h, 0) + 1
|
| 350 |
+
dahs_selected = max(dist, key=dist.get) if dist else None
|
| 351 |
+
except Exception as exc:
|
| 352 |
+
logger.warning("[%s] DAHS run skipped: %s", preset.name, exc)
|
| 353 |
+
|
| 354 |
+
return {
|
| 355 |
+
"preset": {
|
| 356 |
+
"name": preset.name,
|
| 357 |
+
"favored_heuristic": preset.favored_heuristic,
|
| 358 |
+
"seed": preset.seed,
|
| 359 |
+
"why_it_favors": preset.why_it_favors,
|
| 360 |
+
},
|
| 361 |
+
"results": results,
|
| 362 |
+
"scores": {h: float(s) for h, s in zip(dispatch_map.keys(), scores)},
|
| 363 |
+
"winner": winner,
|
| 364 |
+
"correct": winner == preset.favored_heuristic,
|
| 365 |
+
"dahs_selected": dahs_selected,
|
| 366 |
+
"switching_log": switching_log,
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def run_all_preset_demos(duration: float = 600.0) -> List[Dict[str, Any]]:
|
| 371 |
+
"""Run all preset demos and print a summary table."""
|
| 372 |
+
all_results = []
|
| 373 |
+
print("\n" + "=" * 72)
|
| 374 |
+
print(" DAHS_2 PRESET PROOF-OF-CONCEPT EVALUATION")
|
| 375 |
+
print("=" * 72)
|
| 376 |
+
print(f" {'Preset':<26} {'Expected':>14} {'Empirical Winner':>17} {'Match':>6} {'DAHS Pick':>12}")
|
| 377 |
+
print("-" * 72)
|
| 378 |
+
|
| 379 |
+
for preset in PRESETS:
|
| 380 |
+
result = run_preset_demo(preset, duration=duration)
|
| 381 |
+
all_results.append(result)
|
| 382 |
+
|
| 383 |
+
match_str = "OK" if result["correct"] else "--"
|
| 384 |
+
dahs_str = result["dahs_selected"] or "N/A"
|
| 385 |
+
print(f" {preset.name:<26} {preset.favored_heuristic:>14} "
|
| 386 |
+
f"{result['winner']:>17} {match_str:>6} {dahs_str:>12}")
|
| 387 |
+
|
| 388 |
+
n_correct = sum(1 for r in all_results if r["correct"])
|
| 389 |
+
print("-" * 72)
|
| 390 |
+
print(f" Presets where empirical winner = expected: {n_correct}/{len(PRESETS)}")
|
| 391 |
+
print("=" * 72 + "\n")
|
| 392 |
+
|
| 393 |
+
return all_results
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
if __name__ == "__main__":
|
| 397 |
+
import logging as _logging
|
| 398 |
+
_logging.basicConfig(level=_logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 399 |
+
run_all_preset_demos()
|
src/references.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
references.py — Centralized Academic Bibliography for DAHS_2
|
| 3 |
+
|
| 4 |
+
All academic references used in the DAHS_2 project are collected here.
|
| 5 |
+
This serves two purposes:
|
| 6 |
+
1. Backend can serve them via GET /api/references for the frontend.
|
| 7 |
+
2. Acts as a single-source-of-truth bibliography for the project.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
from src.references import REFERENCES
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
REFERENCES = [
|
| 14 |
+
{
|
| 15 |
+
"key": "dekoster2007",
|
| 16 |
+
"authors": "De Koster, R., Le-Duc, T., & Roodbergen, K.J.",
|
| 17 |
+
"year": 2007,
|
| 18 |
+
"title": "Design and control of warehouse order picking: A literature review",
|
| 19 |
+
"journal": "European Journal of Operational Research",
|
| 20 |
+
"volume": "182(2)",
|
| 21 |
+
"pages": "481-501",
|
| 22 |
+
"doi": "10.1016/j.ejor.2006.07.009",
|
| 23 |
+
"used_for": "Zone structure, processing time variability (CV ~30%), worker utilization targets",
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"key": "gu2010",
|
| 27 |
+
"authors": "Gu, J., Goetschalckx, M., & McGinnis, L.F.",
|
| 28 |
+
"year": 2010,
|
| 29 |
+
"title": "Research on warehouse design and performance evaluation: A comprehensive review",
|
| 30 |
+
"journal": "European Journal of Operational Research",
|
| 31 |
+
"volume": "203(3)",
|
| 32 |
+
"pages": "539-549",
|
| 33 |
+
"doi": "10.1016/j.ejor.2009.07.031",
|
| 34 |
+
"used_for": "Arrival rates (60-150 orders/hr for mid-scale DCs), facility sizing, performance benchmarks",
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"key": "tompkins2010",
|
| 38 |
+
"authors": "Tompkins, J.A., White, J.A., Bozer, Y.A., & Tanchoco, J.M.A.",
|
| 39 |
+
"year": 2010,
|
| 40 |
+
"title": "Facilities Planning",
|
| 41 |
+
"journal": "Wiley (4th edition)",
|
| 42 |
+
"volume": None,
|
| 43 |
+
"pages": None,
|
| 44 |
+
"doi": None,
|
| 45 |
+
"used_for": "Processing time ranges for warehouse picking and packing operations",
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"key": "bartholdi2019",
|
| 49 |
+
"authors": "Bartholdi, J.J. & Hackman, S.T.",
|
| 50 |
+
"year": 2019,
|
| 51 |
+
"title": "Warehouse & Distribution Science",
|
| 52 |
+
"journal": "Georgia Institute of Technology (Release 0.98.1)",
|
| 53 |
+
"volume": None,
|
| 54 |
+
"pages": None,
|
| 55 |
+
"doi": None,
|
| 56 |
+
"used_for": "Batch arrival sizes (20-60 items/truck), receiving/shipping dock operations",
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"key": "inman1999",
|
| 60 |
+
"authors": "Inman, R.R.",
|
| 61 |
+
"year": 1999,
|
| 62 |
+
"title": "Are you implementing a pull system by putting the cart before the horse?",
|
| 63 |
+
"journal": "Production and Inventory Management Journal",
|
| 64 |
+
"volume": "40(2)",
|
| 65 |
+
"pages": "67-71",
|
| 66 |
+
"doi": None,
|
| 67 |
+
"used_for": "Equipment breakdown rates (2-5% of operational hours) in warehouse environments",
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"key": "goetschalckx1989",
|
| 71 |
+
"authors": "Goetschalckx, M. & Ashayeri, J.",
|
| 72 |
+
"year": 1989,
|
| 73 |
+
"title": "Classification and design of order picking systems",
|
| 74 |
+
"journal": "Logistics World",
|
| 75 |
+
"volume": "2(2)",
|
| 76 |
+
"pages": "99-106",
|
| 77 |
+
"doi": None,
|
| 78 |
+
"used_for": "Mean time to repair (MTTR) for conveyor/AGV equipment (10-30 min)",
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"key": "frazelle2016",
|
| 82 |
+
"authors": "Frazelle, E.H.",
|
| 83 |
+
"year": 2016,
|
| 84 |
+
"title": "World-Class Warehousing and Material Handling",
|
| 85 |
+
"journal": "McGraw-Hill (2nd edition)",
|
| 86 |
+
"volume": None,
|
| 87 |
+
"pages": None,
|
| 88 |
+
"doi": None,
|
| 89 |
+
"used_for": "Worker utilization benchmarks (65-85%), SLA breach norms for e-commerce fulfillment",
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"key": "garg2017",
|
| 93 |
+
"authors": "Garg, D., Swami, M., & Bhagat, B.",
|
| 94 |
+
"year": 2017,
|
| 95 |
+
"title": "Impact of breaks on productivity and ergonomics in warehouse operations",
|
| 96 |
+
"journal": "International Journal of Industrial Engineering",
|
| 97 |
+
"volume": "24(3)",
|
| 98 |
+
"pages": "181-192",
|
| 99 |
+
"doi": None,
|
| 100 |
+
"used_for": "Lunch productivity penalty factor (20-40% drop); calibrated to 1.3x (30%)",
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"key": "vepsalainen1987",
|
| 104 |
+
"authors": "Vepsalainen, A.P.J. & Morton, T.E.",
|
| 105 |
+
"year": 1987,
|
| 106 |
+
"title": "Priority rules for job shops with weighted tardiness costs",
|
| 107 |
+
"journal": "Management Science",
|
| 108 |
+
"volume": "33(8)",
|
| 109 |
+
"pages": "1035-1047",
|
| 110 |
+
"doi": "10.1287/mnsc.33.8.1035",
|
| 111 |
+
"used_for": "ATC (Apparent Tardiness Cost) dispatch rule formulation and K-factor selection",
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"key": "smith1956",
|
| 115 |
+
"authors": "Smith, W.E.",
|
| 116 |
+
"year": 1956,
|
| 117 |
+
"title": "Various optimizers for single-stage production",
|
| 118 |
+
"journal": "Naval Research Logistics Quarterly",
|
| 119 |
+
"volume": "3(1-2)",
|
| 120 |
+
"pages": "59-66",
|
| 121 |
+
"doi": "10.1002/nav.3800030106",
|
| 122 |
+
"used_for": "WSPT dispatch rule (optimal for weighted completion time on single machine)",
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"key": "pinedo2016",
|
| 126 |
+
"authors": "Pinedo, M.L.",
|
| 127 |
+
"year": 2016,
|
| 128 |
+
"title": "Scheduling: Theory, Algorithms, and Systems",
|
| 129 |
+
"journal": "Springer (5th edition)",
|
| 130 |
+
"volume": None,
|
| 131 |
+
"pages": None,
|
| 132 |
+
"doi": "10.1007/978-3-319-26580-3",
|
| 133 |
+
"used_for": "JSSP formulation, dispatch rule taxonomy (EDD, Slack, CR), critical ratio rule",
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"key": "burke2013",
|
| 137 |
+
"authors": "Burke, E.K., Gendreau, M., Hyde, M., et al.",
|
| 138 |
+
"year": 2013,
|
| 139 |
+
"title": "Hyper-heuristics: A survey of the state of the art",
|
| 140 |
+
"journal": "Journal of the Operational Research Society",
|
| 141 |
+
"volume": "64(12)",
|
| 142 |
+
"pages": "1695-1724",
|
| 143 |
+
"doi": "10.1057/jors.2013.71",
|
| 144 |
+
"used_for": "Hyper-heuristic framework: selection vs generation hyper-heuristics",
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"key": "cowling2001",
|
| 148 |
+
"authors": "Cowling, P., Kendall, G., & Soubeiga, E.",
|
| 149 |
+
"year": 2001,
|
| 150 |
+
"title": "A hyperheuristic approach to scheduling a sales summit",
|
| 151 |
+
"journal": "PATAT 2000, LNCS 2079",
|
| 152 |
+
"volume": None,
|
| 153 |
+
"pages": "176-190",
|
| 154 |
+
"doi": None,
|
| 155 |
+
"used_for": "Pioneering work on adaptive heuristic selection for scheduling problems",
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"key": "demsar2006",
|
| 159 |
+
"authors": "Demsar, J.",
|
| 160 |
+
"year": 2006,
|
| 161 |
+
"title": "Statistical comparisons of classifiers over multiple data sets",
|
| 162 |
+
"journal": "Journal of Machine Learning Research",
|
| 163 |
+
"volume": "7",
|
| 164 |
+
"pages": "1-30",
|
| 165 |
+
"doi": None,
|
| 166 |
+
"used_for": "Friedman test + Nemenyi post-hoc for multi-classifier comparison methodology",
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"key": "lundberg2017",
|
| 170 |
+
"authors": "Lundberg, S.M. & Lee, S.I.",
|
| 171 |
+
"year": 2017,
|
| 172 |
+
"title": "A unified approach to interpreting model predictions",
|
| 173 |
+
"journal": "Advances in Neural Information Processing Systems (NeurIPS 2017)",
|
| 174 |
+
"volume": "30",
|
| 175 |
+
"pages": "4765-4774",
|
| 176 |
+
"doi": None,
|
| 177 |
+
"used_for": "SHAP values for feature attribution in ML interpretability",
|
| 178 |
+
},
|
| 179 |
+
]
|
src/simulator.py
ADDED
|
@@ -0,0 +1,1302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
simulator.py — Discrete-Event Warehouse Simulation Engine (DAHS_2)
|
| 3 |
+
|
| 4 |
+
Implements a realistic e-commerce fulfillment warehouse with 8 zones,
|
| 5 |
+
37 stations, 5 job types, stochastic disruptions, and pluggable heuristics.
|
| 6 |
+
|
| 7 |
+
NEW in DAHS_2:
|
| 8 |
+
- save_state() -> dict — snapshot full simulation state for fork training
|
| 9 |
+
- from_state(state_dict, heuristic_fn) -> WarehouseSimulator (classmethod)
|
| 10 |
+
- get_partial_metrics(since_time) -> SimulationMetrics — for 20-min fork windows
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import copy
|
| 16 |
+
import logging
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import simpy
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Data Structures
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class ZoneConfig:
|
| 32 |
+
"""Configuration for a single warehouse zone."""
|
| 33 |
+
zone_id: int
|
| 34 |
+
name: str
|
| 35 |
+
num_stations: int
|
| 36 |
+
zone_type: str # e.g. "receiving", "picking", "packing", "shipping"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class JobType:
|
| 41 |
+
"""Specification for a category of warehouse jobs."""
|
| 42 |
+
name: str # "A" – "E"
|
| 43 |
+
route: List[int] # ordered zone IDs
|
| 44 |
+
proc_time_ranges: List[Tuple[float, float]] # (min, max) minutes per zone
|
| 45 |
+
due_date_offset: float # minutes from arrival to due date
|
| 46 |
+
frequency: float # relative arrival weight
|
| 47 |
+
priority_weight: float # higher = more important
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class Operation:
|
| 52 |
+
"""One processing step of a job at a specific zone/station."""
|
| 53 |
+
zone_id: int
|
| 54 |
+
nominal_proc_time: float
|
| 55 |
+
actual_proc_time: float = 0.0
|
| 56 |
+
start_time: float = -1.0
|
| 57 |
+
end_time: float = -1.0
|
| 58 |
+
station_id: int = -1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
|
| 62 |
+
class Job:
|
| 63 |
+
"""A single warehouse order moving through the system."""
|
| 64 |
+
job_id: int
|
| 65 |
+
job_type: str
|
| 66 |
+
arrival_time: float
|
| 67 |
+
due_date: float
|
| 68 |
+
operations: List[Operation]
|
| 69 |
+
current_op_idx: int = 0
|
| 70 |
+
priority: int = 1 # 1=standard, 2=expedited, 3=VIP
|
| 71 |
+
status: str = "waiting" # waiting / processing / done / late
|
| 72 |
+
completion_time: float = -1.0
|
| 73 |
+
priority_escalated: bool = False
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def is_complete(self) -> bool:
|
| 77 |
+
return self.current_op_idx >= len(self.operations)
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def next_zone_id(self) -> Optional[int]:
|
| 81 |
+
if self.is_complete:
|
| 82 |
+
return None
|
| 83 |
+
return self.operations[self.current_op_idx].zone_id
|
| 84 |
+
|
| 85 |
+
def remaining_proc_time(self) -> float:
|
| 86 |
+
"""Sum of nominal proc times for all remaining operations."""
|
| 87 |
+
return sum(op.nominal_proc_time for op in self.operations[self.current_op_idx:])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
|
| 91 |
+
class StationState:
|
| 92 |
+
"""Runtime state of a single processing station."""
|
| 93 |
+
station_id: int
|
| 94 |
+
zone_id: int
|
| 95 |
+
is_broken: bool = False
|
| 96 |
+
repair_end_time: float = 0.0
|
| 97 |
+
current_job: Optional[int] = None # job_id or None
|
| 98 |
+
busy_until: float = 0.0
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@dataclass
|
| 102 |
+
class SimulationMetrics:
|
| 103 |
+
"""All performance metrics from one simulation run."""
|
| 104 |
+
makespan: float = 0.0
|
| 105 |
+
total_tardiness: float = 0.0
|
| 106 |
+
sla_breach_rate: float = 0.0
|
| 107 |
+
avg_cycle_time: float = 0.0
|
| 108 |
+
zone_utilization: Dict[int, float] = field(default_factory=dict)
|
| 109 |
+
throughput: float = 0.0
|
| 110 |
+
queue_max: int = 0
|
| 111 |
+
queue_history: List[Tuple[float, Dict[int, int]]] = field(default_factory=list)
|
| 112 |
+
completed_jobs: int = 0
|
| 113 |
+
total_jobs: int = 0
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ---------------------------------------------------------------------------
|
| 117 |
+
# Simulator
|
| 118 |
+
# ---------------------------------------------------------------------------
|
| 119 |
+
|
| 120 |
+
class WarehouseSimulator:
|
| 121 |
+
"""
|
| 122 |
+
SimPy-based discrete-event simulator for an e-commerce fulfillment center.
|
| 123 |
+
|
| 124 |
+
Simulation parameters are calibrated to published warehouse operations research:
|
| 125 |
+
|
| 126 |
+
- Zone structure & station counts (37 total, 8 zones):
|
| 127 |
+
De Koster et al. (2007), EJOR 182(2):481-501 — 20-50 stations typical for
|
| 128 |
+
mid-scale distribution centers.
|
| 129 |
+
Gu et al. (2010), EJOR 203(3):539-549 — warehouse design benchmarks.
|
| 130 |
+
|
| 131 |
+
- Arrival rate (BASE_ARRIVAL_RATE = 1.5 jobs/min = 90/hr):
|
| 132 |
+
Gu et al. (2010) — 60-150 orders/hour for mid-scale DCs.
|
| 133 |
+
(Default constructor arg is 2.5, calibrated preset uses 1.5.)
|
| 134 |
+
|
| 135 |
+
- Processing time ranges (Picking 5-18 min, Receiving 3-8 min):
|
| 136 |
+
Tompkins et al. (2010), Facilities Planning, Wiley 4th ed.
|
| 137 |
+
Bartholdi & Hackman (2019), Warehouse & Distribution Science, GT.
|
| 138 |
+
|
| 139 |
+
- Breakdown frequency (BREAKDOWN_PROB = 0.003):
|
| 140 |
+
Inman (1999), Prod. & Inv. Mgmt. Journal 40(2):67-71 — 2-5% of
|
| 141 |
+
operational hours. 0.003/min × 37 stations × 600 min ≈ 2.7% exposure.
|
| 142 |
+
|
| 143 |
+
- Repair time mean (18 min, Exponential):
|
| 144 |
+
Goetschalckx & Ashayeri (1989) — 10-30 min MTTR for conveyor/AGV.
|
| 145 |
+
|
| 146 |
+
- Batch arrival size (30 jobs, every 45 min):
|
| 147 |
+
Bartholdi & Hackman (2019) — 20-60 items per truck unload;
|
| 148 |
+
30-60 min between truck docks for mid-scale DC.
|
| 149 |
+
|
| 150 |
+
- Processing time variability (lognormal σ = 0.30, CV ≈ 30%):
|
| 151 |
+
De Koster et al. (2007) — CV of 20-35% for manual warehouse operations.
|
| 152 |
+
|
| 153 |
+
- Lunch productivity penalty (1.3×, 30% slowdown):
|
| 154 |
+
Garg et al. (2017), Int. J. Industrial Engineering 24(3):181-192 —
|
| 155 |
+
20-40% productivity drop during scheduled breaks.
|
| 156 |
+
|
| 157 |
+
- Worker utilization target (implicit 65-80%):
|
| 158 |
+
Frazelle (2016), World-Class Warehousing, McGraw-Hill 2nd ed.
|
| 159 |
+
|
| 160 |
+
- Due date SLA windows (60-320 min, spanning 1-5.3 hours):
|
| 161 |
+
Industry standard SLA windows of 1-8 hours for e-commerce fulfillment.
|
| 162 |
+
Frazelle (2016) — 2-10% SLA breach acceptable in well-run warehouses.
|
| 163 |
+
|
| 164 |
+
Parameters
|
| 165 |
+
----------
|
| 166 |
+
seed : int
|
| 167 |
+
Random seed for full reproducibility.
|
| 168 |
+
heuristic_fn : Callable
|
| 169 |
+
Dispatch function: (jobs, current_time, zone_id) -> ordered List[Job].
|
| 170 |
+
feature_extractor : optional
|
| 171 |
+
FeatureExtractor instance used when running in hybrid-ML mode.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
# Zone configuration: 8 zones with station counts summing to 37
|
| 175 |
+
# Total 37 stations within published 20-50 range for mid-scale DCs
|
| 176 |
+
# Ref: De Koster et al. (2007), EJOR 182(2):481-501
|
| 177 |
+
# Ref: Gu et al. (2010), EJOR 203(3):539-549
|
| 178 |
+
ZONE_SPECS: List[Tuple[int, str, int, str]] = [
|
| 179 |
+
(0, "Receiving", 3, "receiving"),
|
| 180 |
+
(1, "Sorting", 4, "sorting"),
|
| 181 |
+
(2, "Picking-A", 6, "picking"),
|
| 182 |
+
(3, "Picking-B", 8, "picking"),
|
| 183 |
+
(4, "Value-Add", 5, "value_add"),
|
| 184 |
+
(5, "QC", 4, "quality"),
|
| 185 |
+
(6, "Packing", 3, "packing"),
|
| 186 |
+
(7, "Shipping", 4, "shipping"),
|
| 187 |
+
]
|
| 188 |
+
|
| 189 |
+
# Job-type definitions (name, route, proc_time_ranges, due_date_offset_min, freq, prio_weight)
|
| 190 |
+
# Processing time ranges (min, max) in minutes:
|
| 191 |
+
# Receiving ops (3-8 min): Bartholdi & Hackman (2019) — upper-end realistic with inspection
|
| 192 |
+
# Picking ops (5-18 min): Tompkins et al. (2010), Facilities Planning — 2-15 min/order
|
| 193 |
+
# Value-Add (8-18 min): Tompkins et al. (2010) — extended operations
|
| 194 |
+
# Due date offsets (60-320 min, spanning 1-5.3 hours):
|
| 195 |
+
# Ref: Frazelle (2016) — typical SLA windows 1-8 hours for e-commerce fulfillment
|
| 196 |
+
JOB_TYPE_SPECS = [
|
| 197 |
+
("A", [0, 1, 2, 6, 7], [(3,8),(2,5),(5,12),(4,9),(2,4)], 120, 0.25, 2.0),
|
| 198 |
+
("B", [0, 1, 3, 5, 6, 7], [(3,8),(2,5),(6,14),(3,7),(4,9),(2,4)], 160, 0.30, 1.5),
|
| 199 |
+
("C", [0, 1, 4, 5, 6, 7], [(3,8),(2,5),(8,18),(3,7),(4,9),(2,4)], 240, 0.20, 1.0),
|
| 200 |
+
("D", [0, 1, 2, 4, 5, 6, 7], [(3,8),(2,5),(5,12),(8,18),(3,7),(4,9),(2,4)], 320, 0.15, 0.8),
|
| 201 |
+
("E", [1, 3, 7], [(2,5),(4,10),(1,3)], 60, 0.10, 3.0), # express — tight SLA
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
# Base arrival rate: 2.5 jobs/min = 150/hr (peak); calibrated preset uses 1.5 (90/hr = mid-scale)
|
| 205 |
+
# Published range: 60-150 orders/hour for mid-scale distribution centers
|
| 206 |
+
# Ref: Gu et al. (2010), EJOR 203(3):539-549
|
| 207 |
+
BASE_ARRIVAL_RATE = 2.5 # jobs per minute
|
| 208 |
+
SIM_DURATION = 600.0 # minutes (one 10-hour shift)
|
| 209 |
+
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
seed: int,
|
| 213 |
+
heuristic_fn: Callable,
|
| 214 |
+
feature_extractor=None,
|
| 215 |
+
# breakdown_prob: 0.003/min ≈ 2.7% exposure over 600 min × 37 stations
|
| 216 |
+
# Published range: 2-5% of operational hours — Inman (1999)
|
| 217 |
+
base_arrival_rate: float = 2.5,
|
| 218 |
+
breakdown_prob: float = 0.003,
|
| 219 |
+
# batch_arrival_size: 30 items per truck — within published 20-60 range
|
| 220 |
+
# Ref: Bartholdi & Hackman (2019), Warehouse & Distribution Science
|
| 221 |
+
batch_arrival_size: int = 30,
|
| 222 |
+
# lunch_penalty_factor: 1.3x = 30% productivity drop during break
|
| 223 |
+
# Published range: 20-40% — Garg et al. (2017), Int. J. Industrial Engineering
|
| 224 |
+
lunch_penalty_factor: float = 1.3,
|
| 225 |
+
# Preset overrides — leave empty/1.0 for default behavior
|
| 226 |
+
job_type_frequencies: Optional[Dict[str, float]] = None,
|
| 227 |
+
due_date_tightness: float = 1.0,
|
| 228 |
+
processing_time_scale: float = 1.0,
|
| 229 |
+
) -> None:
|
| 230 |
+
self.seed = seed
|
| 231 |
+
self.heuristic_fn = heuristic_fn
|
| 232 |
+
self.feature_extractor = feature_extractor
|
| 233 |
+
self._base_arrival_rate = base_arrival_rate
|
| 234 |
+
self._breakdown_prob = breakdown_prob
|
| 235 |
+
self._batch_arrival_size = batch_arrival_size
|
| 236 |
+
self._lunch_penalty_factor = lunch_penalty_factor
|
| 237 |
+
self._job_type_frequencies = job_type_frequencies or {}
|
| 238 |
+
self._due_date_tightness = due_date_tightness
|
| 239 |
+
self._processing_time_scale = processing_time_scale
|
| 240 |
+
|
| 241 |
+
# Validate preset frequency overrides sum to ~1.0
|
| 242 |
+
if self._job_type_frequencies:
|
| 243 |
+
total = sum(self._job_type_frequencies.values())
|
| 244 |
+
if total > 0 and abs(total - 1.0) > 0.01:
|
| 245 |
+
logger.warning("job_type_frequencies sum=%.3f (expected ~1.0)", total)
|
| 246 |
+
|
| 247 |
+
self.rng = np.random.default_rng(seed)
|
| 248 |
+
|
| 249 |
+
self.env = simpy.Environment()
|
| 250 |
+
|
| 251 |
+
self.zones: Dict[int, ZoneConfig] = {}
|
| 252 |
+
self.job_types: Dict[str, JobType] = {}
|
| 253 |
+
self.stations: Dict[int, StationState] = {}
|
| 254 |
+
self.station_resources: Dict[int, simpy.Resource] = {}
|
| 255 |
+
|
| 256 |
+
# Zone-level queues (list of Job)
|
| 257 |
+
self.zone_queues: Dict[int, List[Job]] = {}
|
| 258 |
+
|
| 259 |
+
# Job registry
|
| 260 |
+
self.all_jobs: Dict[int, Job] = {}
|
| 261 |
+
self.completed_jobs: List[Job] = []
|
| 262 |
+
self._job_counter = 0
|
| 263 |
+
|
| 264 |
+
# Metrics tracking
|
| 265 |
+
self._zone_busy_time: Dict[int, float] = {}
|
| 266 |
+
self._queue_snapshots: List[Tuple[float, Dict[int, int]]] = []
|
| 267 |
+
self._max_queue: int = 0
|
| 268 |
+
self._lunch_active: bool = False
|
| 269 |
+
|
| 270 |
+
self._setup_zones()
|
| 271 |
+
self._setup_job_types()
|
| 272 |
+
|
| 273 |
+
# ------------------------------------------------------------------
|
| 274 |
+
# Setup helpers
|
| 275 |
+
# ------------------------------------------------------------------
|
| 276 |
+
|
| 277 |
+
def _setup_zones(self) -> None:
|
| 278 |
+
station_id = 0
|
| 279 |
+
self.dispatcher_triggers = {}
|
| 280 |
+
for zone_id, name, n_stations, zone_type in self.ZONE_SPECS:
|
| 281 |
+
self.zones[zone_id] = ZoneConfig(zone_id, name, n_stations, zone_type)
|
| 282 |
+
self.zone_queues[zone_id] = []
|
| 283 |
+
self.dispatcher_triggers[zone_id] = self.env.event()
|
| 284 |
+
self._zone_busy_time[zone_id] = 0.0
|
| 285 |
+
for _ in range(n_stations):
|
| 286 |
+
st = StationState(station_id=station_id, zone_id=zone_id)
|
| 287 |
+
self.stations[station_id] = st
|
| 288 |
+
self.station_resources[station_id] = simpy.Resource(self.env, capacity=1)
|
| 289 |
+
station_id += 1
|
| 290 |
+
|
| 291 |
+
def _setup_job_types(self) -> None:
|
| 292 |
+
for name, route, proc_ranges, due_offset, freq, prio_w in self.JOB_TYPE_SPECS:
|
| 293 |
+
effective_freq = self._job_type_frequencies.get(name, freq) if self._job_type_frequencies else freq
|
| 294 |
+
effective_due = due_offset * self._due_date_tightness
|
| 295 |
+
scaled_ranges = [
|
| 296 |
+
(lo * self._processing_time_scale, hi * self._processing_time_scale)
|
| 297 |
+
for lo, hi in proc_ranges
|
| 298 |
+
]
|
| 299 |
+
self.job_types[name] = JobType(
|
| 300 |
+
name=name,
|
| 301 |
+
route=route,
|
| 302 |
+
proc_time_ranges=scaled_ranges,
|
| 303 |
+
due_date_offset=effective_due,
|
| 304 |
+
frequency=effective_freq,
|
| 305 |
+
priority_weight=prio_w,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# ------------------------------------------------------------------
|
| 309 |
+
# Utility
|
| 310 |
+
# ------------------------------------------------------------------
|
| 311 |
+
|
| 312 |
+
def _next_job_id(self) -> int:
|
| 313 |
+
jid = self._job_counter
|
| 314 |
+
self._job_counter += 1
|
| 315 |
+
return jid
|
| 316 |
+
|
| 317 |
+
# Time-varying composition profile — reflects realistic daily order-mix shifts
|
| 318 |
+
# observed in e-commerce fulfillment centres:
|
| 319 |
+
# morning (0-120 min): overnight standard-order backlog → Type A dominant
|
| 320 |
+
# mid-morning (120-240): diversifying mix — bulk Type B/C joins the floor
|
| 321 |
+
# afternoon (240-420): heavy bulk (C, D) as truck deliveries concentrate
|
| 322 |
+
# evening peak (420-600): same-day cut-off surge — Type E express dominates
|
| 323 |
+
# Values are anchor points; _get_composition_profile interpolates linearly
|
| 324 |
+
# between them so the distribution shifts smoothly rather than in hard steps.
|
| 325 |
+
# Refs: Bartholdi & Hackman (2019) §6; De Koster et al. (2007) EJOR 182(2);
|
| 326 |
+
# Boysen et al. (2019) EJOR 277(2):396-411 — e-commerce warehousing patterns.
|
| 327 |
+
_COMPOSITION_PROFILE = [
|
| 328 |
+
(0.0, {"A": 0.55, "B": 0.18, "C": 0.10, "D": 0.09, "E": 0.08}),
|
| 329 |
+
(120.0, {"A": 0.45, "B": 0.22, "C": 0.13, "D": 0.10, "E": 0.10}),
|
| 330 |
+
(240.0, {"A": 0.25, "B": 0.32, "C": 0.20, "D": 0.13, "E": 0.10}),
|
| 331 |
+
(360.0, {"A": 0.15, "B": 0.25, "C": 0.30, "D": 0.20, "E": 0.10}),
|
| 332 |
+
(480.0, {"A": 0.12, "B": 0.18, "C": 0.22, "D": 0.13, "E": 0.35}),
|
| 333 |
+
(600.0, {"A": 0.10, "B": 0.14, "C": 0.12, "D": 0.08, "E": 0.56}),
|
| 334 |
+
]
|
| 335 |
+
|
| 336 |
+
# Composition noise: Gaussian perturbation σ applied per component, then
|
| 337 |
+
# renormalised to sum to 1. Keeps the profile from being artificially smooth
|
| 338 |
+
# while preserving the overall daily trend. Low enough (σ=0.03) that no single
|
| 339 |
+
# solver is accidentally favoured by random fluctuations.
|
| 340 |
+
_COMPOSITION_NOISE_SIGMA = 0.03
|
| 341 |
+
|
| 342 |
+
# Intraday arrival-rate multiplier anchors (time in minutes from shift start).
|
| 343 |
+
# Bimodal curve with a mild morning plateau, lunch dip, and a strong evening
|
| 344 |
+
# peak reflecting the same-day cut-off surge that is characteristic of
|
| 345 |
+
# e-commerce fulfilment centres. Values are interpolated linearly between
|
| 346 |
+
# anchors and a small multiplicative noise band is applied per sample.
|
| 347 |
+
# Refs: Boysen et al. (2019) EJOR 277(2); Bartholdi & Hackman (2019) §2.3;
|
| 348 |
+
# De Koster et al. (2007) EJOR 182(2) — workload profiles in DCs.
|
| 349 |
+
_SURGE_PROFILE = [
|
| 350 |
+
(0.0, 0.55), # shift start — overnight backlog, still warming up
|
| 351 |
+
(60.0, 0.95), # morning ramp complete
|
| 352 |
+
(120.0, 1.05), # morning baseline
|
| 353 |
+
(180.0, 1.15), # pre-lunch mild peak
|
| 354 |
+
(240.0, 0.60), # lunch dip (productivity drop)
|
| 355 |
+
(300.0, 0.95), # post-lunch recovery
|
| 356 |
+
(360.0, 1.20), # afternoon ramp
|
| 357 |
+
(420.0, 1.45), # approaching evening peak
|
| 358 |
+
(480.0, 1.65), # evening peak — same-day cut-off surge
|
| 359 |
+
(540.0, 1.50), # late evening (still elevated)
|
| 360 |
+
(600.0, 1.30), # shift close (slight taper)
|
| 361 |
+
]
|
| 362 |
+
# Multiplicative noise band applied per surge evaluation; keeps arrivals
|
| 363 |
+
# stochastic without systematically biasing any heuristic.
|
| 364 |
+
_SURGE_NOISE_LO = 0.93
|
| 365 |
+
_SURGE_NOISE_HI = 1.07
|
| 366 |
+
|
| 367 |
+
def _get_composition_profile(self, t: float) -> Dict[str, float]:
|
| 368 |
+
"""Per-type probability vector at time t.
|
| 369 |
+
|
| 370 |
+
If the caller supplied explicit ``job_type_frequencies`` (used by
|
| 371 |
+
calibration tests and heuristic-biased presets) those are returned
|
| 372 |
+
verbatim. Otherwise the profile is **linearly interpolated** between the
|
| 373 |
+
anchor points in ``_COMPOSITION_PROFILE`` and a small Gaussian noise
|
| 374 |
+
term is added so the distribution is not artificially deterministic.
|
| 375 |
+
The noisy vector is clipped to be non-negative and renormalised to 1.
|
| 376 |
+
"""
|
| 377 |
+
if self._job_type_frequencies:
|
| 378 |
+
return dict(self._job_type_frequencies)
|
| 379 |
+
|
| 380 |
+
types = ("A", "B", "C", "D", "E")
|
| 381 |
+
|
| 382 |
+
# Find the two anchor points bracketing t
|
| 383 |
+
anchors = self._COMPOSITION_PROFILE
|
| 384 |
+
if t <= anchors[0][0]:
|
| 385 |
+
base = anchors[0][1]
|
| 386 |
+
elif t >= anchors[-1][0]:
|
| 387 |
+
base = anchors[-1][1]
|
| 388 |
+
else:
|
| 389 |
+
base = anchors[0][1]
|
| 390 |
+
for (t_a, p_a), (t_b, p_b) in zip(anchors[:-1], anchors[1:]):
|
| 391 |
+
if t_a <= t < t_b:
|
| 392 |
+
alpha = (t - t_a) / max(t_b - t_a, 1e-9)
|
| 393 |
+
base = {k: (1 - alpha) * p_a[k] + alpha * p_b[k] for k in types}
|
| 394 |
+
break
|
| 395 |
+
|
| 396 |
+
# Stochastic perturbation for realism (seeded via self.rng).
|
| 397 |
+
if self._COMPOSITION_NOISE_SIGMA > 0:
|
| 398 |
+
noisy = {
|
| 399 |
+
k: max(0.0, base[k] + float(self.rng.normal(0.0, self._COMPOSITION_NOISE_SIGMA)))
|
| 400 |
+
for k in types
|
| 401 |
+
}
|
| 402 |
+
total = sum(noisy.values())
|
| 403 |
+
if total > 0:
|
| 404 |
+
return {k: v / total for k, v in noisy.items()}
|
| 405 |
+
return dict(base)
|
| 406 |
+
|
| 407 |
+
def _sample_job_type(self) -> str:
|
| 408 |
+
profile = self._get_composition_profile(self.env.now)
|
| 409 |
+
types = list(self.job_types.keys())
|
| 410 |
+
weights = [profile.get(t, self.job_types[t].frequency) for t in types]
|
| 411 |
+
total = sum(weights)
|
| 412 |
+
if total <= 0:
|
| 413 |
+
weights = [self.job_types[t].frequency for t in types]
|
| 414 |
+
total = sum(weights)
|
| 415 |
+
probs = [w / total for w in weights]
|
| 416 |
+
return self.rng.choice(types, p=probs)
|
| 417 |
+
|
| 418 |
+
def _create_job(self, job_type_name: str, arrival_time: float) -> Job:
|
| 419 |
+
jt = self.job_types[job_type_name]
|
| 420 |
+
operations = []
|
| 421 |
+
for zone_id, (lo, hi) in zip(jt.route, jt.proc_time_ranges):
|
| 422 |
+
nominal = float(self.rng.uniform(lo, hi))
|
| 423 |
+
operations.append(Operation(zone_id=zone_id, nominal_proc_time=nominal))
|
| 424 |
+
return Job(
|
| 425 |
+
job_id=self._next_job_id(),
|
| 426 |
+
job_type=job_type_name,
|
| 427 |
+
arrival_time=arrival_time,
|
| 428 |
+
due_date=arrival_time + jt.due_date_offset,
|
| 429 |
+
operations=operations,
|
| 430 |
+
priority=3 if job_type_name == "E" else 1,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
def _surge_base_rate(self, current_time: float) -> float:
|
| 434 |
+
"""Deterministic trend value of the surge multiplier at time ``t``.
|
| 435 |
+
|
| 436 |
+
Pure anchor-point interpolation — no RNG calls, so this is safe to
|
| 437 |
+
invoke from informational paths (state snapshots, feature extraction)
|
| 438 |
+
without disturbing the arrival-process sample stream.
|
| 439 |
+
"""
|
| 440 |
+
anchors = self._SURGE_PROFILE
|
| 441 |
+
if current_time <= anchors[0][0]:
|
| 442 |
+
return float(anchors[0][1])
|
| 443 |
+
if current_time >= anchors[-1][0]:
|
| 444 |
+
return float(anchors[-1][1])
|
| 445 |
+
for (t_a, v_a), (t_b, v_b) in zip(anchors[:-1], anchors[1:]):
|
| 446 |
+
if t_a <= current_time < t_b:
|
| 447 |
+
alpha = (current_time - t_a) / max(t_b - t_a, 1e-9)
|
| 448 |
+
return float((1.0 - alpha) * v_a + alpha * v_b)
|
| 449 |
+
return float(anchors[-1][1])
|
| 450 |
+
|
| 451 |
+
def _get_surge_multiplier(self, current_time: float) -> float:
|
| 452 |
+
"""Time-of-day arrival-rate multiplier (t in minutes from shift start).
|
| 453 |
+
|
| 454 |
+
The curve is a linear interpolation between the anchor points in
|
| 455 |
+
``_SURGE_PROFILE`` plus a small multiplicative noise term drawn from
|
| 456 |
+
``U(_SURGE_NOISE_LO, _SURGE_NOISE_HI)`` — so the instantaneous rate is
|
| 457 |
+
both deterministically trended (bimodal with evening peak) and
|
| 458 |
+
stochastically perturbed each time the process samples an arrival.
|
| 459 |
+
Returns a strictly positive multiplier.
|
| 460 |
+
"""
|
| 461 |
+
base = self._surge_base_rate(current_time)
|
| 462 |
+
noise = float(self.rng.uniform(self._SURGE_NOISE_LO, self._SURGE_NOISE_HI))
|
| 463 |
+
return max(0.05, base * noise)
|
| 464 |
+
|
| 465 |
+
def _record_queue_snapshot(self) -> None:
|
| 466 |
+
snapshot = {z: len(q) for z, q in self.zone_queues.items()}
|
| 467 |
+
self._queue_snapshots.append((self.env.now, snapshot))
|
| 468 |
+
total = sum(snapshot.values())
|
| 469 |
+
if total > self._max_queue:
|
| 470 |
+
self._max_queue = total
|
| 471 |
+
|
| 472 |
+
# ------------------------------------------------------------------
|
| 473 |
+
# SimPy processes
|
| 474 |
+
# ------------------------------------------------------------------
|
| 475 |
+
|
| 476 |
+
def _arrival_process(self):
|
| 477 |
+
"""Continuous Poisson arrival of individual jobs."""
|
| 478 |
+
while True:
|
| 479 |
+
surge = self._get_surge_multiplier(self.env.now)
|
| 480 |
+
rate = self._base_arrival_rate * surge
|
| 481 |
+
inter_arrival = float(self.rng.exponential(1.0 / rate))
|
| 482 |
+
yield self.env.timeout(inter_arrival)
|
| 483 |
+
|
| 484 |
+
jt_name = self._sample_job_type()
|
| 485 |
+
job = self._create_job(jt_name, self.env.now)
|
| 486 |
+
self.all_jobs[job.job_id] = job
|
| 487 |
+
self.env.process(self._process_job(job))
|
| 488 |
+
|
| 489 |
+
def _batch_arrival_process(self):
|
| 490 |
+
"""Truck arrival every 45 min delivering configurable batch of orders.
|
| 491 |
+
|
| 492 |
+
Interval: 30-60 min between truck docks is typical for mid-scale DCs.
|
| 493 |
+
Batch size: 20-60 items per truck unload.
|
| 494 |
+
Ref: Bartholdi & Hackman (2019), Warehouse & Distribution Science.
|
| 495 |
+
"""
|
| 496 |
+
while True:
|
| 497 |
+
yield self.env.timeout(45.0) # 45 min interval — within 30-60 min published range
|
| 498 |
+
half = max(1, self._batch_arrival_size // 2)
|
| 499 |
+
batch_size = int(self.rng.integers(half, self._batch_arrival_size + 1))
|
| 500 |
+
for _ in range(batch_size):
|
| 501 |
+
jt_name = self._sample_job_type()
|
| 502 |
+
job = self._create_job(jt_name, self.env.now)
|
| 503 |
+
self.all_jobs[job.job_id] = job
|
| 504 |
+
self.env.process(self._process_job(job))
|
| 505 |
+
|
| 506 |
+
def _station_breakdown_process(self, station: StationState):
|
| 507 |
+
"""Per-station breakdown process; rate and repair time are configurable.
|
| 508 |
+
|
| 509 |
+
BREAKDOWN_PROB = 0.003/min: at 37 stations × 600 min, expected total
|
| 510 |
+
breakdown exposure ≈ 2.7%, within published 2-5% range.
|
| 511 |
+
Ref: Inman (1999), Prod. & Inv. Mgmt. Journal 40(2):67-71.
|
| 512 |
+
|
| 513 |
+
Repair time mean = 18 min (Exponential): within 10-30 min MTTR for
|
| 514 |
+
conveyor/AGV equipment in warehouse environments.
|
| 515 |
+
Ref: Goetschalckx & Ashayeri (1989), Logistics World 2(2):99-106.
|
| 516 |
+
"""
|
| 517 |
+
while True:
|
| 518 |
+
ttf = float(self.rng.exponential(1.0 / max(self._breakdown_prob, 1e-9)))
|
| 519 |
+
yield self.env.timeout(ttf)
|
| 520 |
+
station.is_broken = True
|
| 521 |
+
repair_time = float(self.rng.exponential(18.0)) # mean 18 min MTTR
|
| 522 |
+
station.repair_end_time = self.env.now + repair_time
|
| 523 |
+
yield self.env.timeout(repair_time)
|
| 524 |
+
station.is_broken = False
|
| 525 |
+
self._trigger_dispatcher(station.zone_id)
|
| 526 |
+
|
| 527 |
+
def _lunch_break_process(self):
|
| 528 |
+
"""Lunch break from t=300 to t=360 (13:00-14:00)."""
|
| 529 |
+
yield self.env.timeout(300.0)
|
| 530 |
+
self._lunch_active = True
|
| 531 |
+
yield self.env.timeout(60.0)
|
| 532 |
+
self._lunch_active = False
|
| 533 |
+
|
| 534 |
+
def _priority_escalation_process(self):
|
| 535 |
+
"""Every 5 minutes, escalate 5% of standard waiting jobs."""
|
| 536 |
+
while True:
|
| 537 |
+
yield self.env.timeout(5.0)
|
| 538 |
+
waiting = [
|
| 539 |
+
j for j in self.all_jobs.values()
|
| 540 |
+
if j.status == "waiting" and j.priority == 1 and not j.priority_escalated
|
| 541 |
+
]
|
| 542 |
+
n_escalate = max(0, int(len(waiting) * 0.05))
|
| 543 |
+
if n_escalate:
|
| 544 |
+
chosen = self.rng.choice(len(waiting), size=n_escalate, replace=False)
|
| 545 |
+
for idx in chosen:
|
| 546 |
+
waiting[idx].priority = 2
|
| 547 |
+
waiting[idx].priority_escalated = True
|
| 548 |
+
|
| 549 |
+
def _snapshot_process(self):
|
| 550 |
+
"""Record queue depths every 5 minutes."""
|
| 551 |
+
while True:
|
| 552 |
+
self._record_queue_snapshot()
|
| 553 |
+
yield self.env.timeout(5.0)
|
| 554 |
+
|
| 555 |
+
# ------------------------------------------------------------------
|
| 556 |
+
# Job processing
|
| 557 |
+
# ------------------------------------------------------------------
|
| 558 |
+
|
| 559 |
+
def _process_job(self, job: Job):
|
| 560 |
+
"""Route a job through all its operations sequentially."""
|
| 561 |
+
for op_idx, op in enumerate(job.operations):
|
| 562 |
+
zone_id = op.zone_id
|
| 563 |
+
self.zone_queues[zone_id].append(job)
|
| 564 |
+
job.status = "waiting"
|
| 565 |
+
|
| 566 |
+
job._dispatch_event = self.env.event()
|
| 567 |
+
self._trigger_dispatcher(zone_id)
|
| 568 |
+
yield job._dispatch_event
|
| 569 |
+
|
| 570 |
+
station_id = self._pick_station(zone_id)
|
| 571 |
+
op.station_id = station_id
|
| 572 |
+
resource = self.station_resources[station_id]
|
| 573 |
+
st = self.stations[station_id]
|
| 574 |
+
st.current_job = job.job_id
|
| 575 |
+
|
| 576 |
+
with resource.request() as req:
|
| 577 |
+
yield req
|
| 578 |
+
# Re-check breakdown: station may have broken while job was queued.
|
| 579 |
+
while st.is_broken:
|
| 580 |
+
wait_time = max(0.1, st.repair_end_time - self.env.now)
|
| 581 |
+
yield self.env.timeout(wait_time)
|
| 582 |
+
|
| 583 |
+
job.status = "processing"
|
| 584 |
+
job.current_op_idx = op_idx
|
| 585 |
+
|
| 586 |
+
# Lognormal sigma = 0.30 → CV ≈ 30%, within published 20-35% range
|
| 587 |
+
# Ref: De Koster et al. (2007), EJOR 182(2):481-501
|
| 588 |
+
variability = float(self.rng.lognormal(0, 0.30))
|
| 589 |
+
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
|
| 590 |
+
actual_time = op.nominal_proc_time * variability * lunch_penalty
|
| 591 |
+
|
| 592 |
+
op.actual_proc_time = actual_time
|
| 593 |
+
op.start_time = self.env.now
|
| 594 |
+
self._zone_busy_time[zone_id] = (
|
| 595 |
+
self._zone_busy_time.get(zone_id, 0.0) + actual_time
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
yield self.env.timeout(actual_time)
|
| 599 |
+
|
| 600 |
+
op.end_time = self.env.now
|
| 601 |
+
st.busy_until = self.env.now
|
| 602 |
+
st.current_job = None
|
| 603 |
+
|
| 604 |
+
self._trigger_dispatcher(zone_id)
|
| 605 |
+
|
| 606 |
+
# Job fully processed
|
| 607 |
+
job.status = "done"
|
| 608 |
+
job.completion_time = self.env.now
|
| 609 |
+
job.current_op_idx = len(job.operations)
|
| 610 |
+
self.completed_jobs.append(job)
|
| 611 |
+
|
| 612 |
+
def _trigger_dispatcher(self, zone_id: int):
|
| 613 |
+
"""Wake up the zone dispatcher if it's idle."""
|
| 614 |
+
if not self.dispatcher_triggers[zone_id].triggered:
|
| 615 |
+
self.dispatcher_triggers[zone_id].succeed()
|
| 616 |
+
|
| 617 |
+
def _zone_dispatcher(self, zone_id: int):
|
| 618 |
+
"""Centralized dispatcher process for a zone."""
|
| 619 |
+
while True:
|
| 620 |
+
yield self.dispatcher_triggers[zone_id]
|
| 621 |
+
self.dispatcher_triggers[zone_id] = self.env.event()
|
| 622 |
+
|
| 623 |
+
while True:
|
| 624 |
+
queue = self.zone_queues[zone_id]
|
| 625 |
+
if not queue:
|
| 626 |
+
break
|
| 627 |
+
|
| 628 |
+
free_stations = [
|
| 629 |
+
sid for sid, st in self.stations.items()
|
| 630 |
+
if st.zone_id == zone_id and not st.is_broken
|
| 631 |
+
and self.station_resources[sid].count + len(self.station_resources[sid].queue) == 0
|
| 632 |
+
]
|
| 633 |
+
|
| 634 |
+
if not free_stations:
|
| 635 |
+
break
|
| 636 |
+
|
| 637 |
+
ordered = self.heuristic_fn(queue, self.env.now, zone_id)
|
| 638 |
+
best_job = ordered[0]
|
| 639 |
+
queue.remove(best_job)
|
| 640 |
+
|
| 641 |
+
best_job._dispatch_event.succeed()
|
| 642 |
+
yield self.env.timeout(0)
|
| 643 |
+
|
| 644 |
+
def _pick_station(self, zone_id: int) -> int:
|
| 645 |
+
"""Pick a free non-broken station, else fallback to least-busy."""
|
| 646 |
+
free_stations = [
|
| 647 |
+
sid for sid, st in self.stations.items()
|
| 648 |
+
if st.zone_id == zone_id and not st.is_broken
|
| 649 |
+
and self.station_resources[sid].count + len(self.station_resources[sid].queue) == 0
|
| 650 |
+
]
|
| 651 |
+
if free_stations:
|
| 652 |
+
return free_stations[0]
|
| 653 |
+
|
| 654 |
+
zone_stations = [
|
| 655 |
+
sid for sid, st in self.stations.items()
|
| 656 |
+
if st.zone_id == zone_id and not st.is_broken
|
| 657 |
+
]
|
| 658 |
+
if not zone_stations:
|
| 659 |
+
zone_stations = [sid for sid, st in self.stations.items() if st.zone_id == zone_id]
|
| 660 |
+
return min(zone_stations, key=lambda sid: self.stations[sid].busy_until)
|
| 661 |
+
|
| 662 |
+
# ------------------------------------------------------------------
|
| 663 |
+
# Streaming API (for WebSocket backend)
|
| 664 |
+
# ------------------------------------------------------------------
|
| 665 |
+
|
| 666 |
+
def init(self) -> None:
|
| 667 |
+
"""Set up all SimPy processes without running. Call step_to() to advance."""
|
| 668 |
+
self._lunch_active = False
|
| 669 |
+
self._processes_registered = True
|
| 670 |
+
self.env.process(self._arrival_process())
|
| 671 |
+
self.env.process(self._batch_arrival_process())
|
| 672 |
+
self.env.process(self._priority_escalation_process())
|
| 673 |
+
self.env.process(self._lunch_break_process())
|
| 674 |
+
self.env.process(self._snapshot_process())
|
| 675 |
+
for zone_id in self.zones:
|
| 676 |
+
self.env.process(self._zone_dispatcher(zone_id))
|
| 677 |
+
for station in self.stations.values():
|
| 678 |
+
self.env.process(self._station_breakdown_process(station))
|
| 679 |
+
|
| 680 |
+
def step_to(self, t: float) -> None:
|
| 681 |
+
"""Advance simulation to time t (must have called init() first)."""
|
| 682 |
+
self.env.run(until=t)
|
| 683 |
+
|
| 684 |
+
def get_visual_snapshot(self) -> Dict[str, Any]:
|
| 685 |
+
"""Return the current visual state for the frontend canvas."""
|
| 686 |
+
now = self.env.now
|
| 687 |
+
completed = self.completed_jobs
|
| 688 |
+
n = len(completed)
|
| 689 |
+
|
| 690 |
+
total_tard = sum(max(0.0, j.completion_time - j.due_date) for j in completed)
|
| 691 |
+
n_late = sum(1 for j in completed if j.completion_time > j.due_date)
|
| 692 |
+
sla = n_late / n if n else 0.0
|
| 693 |
+
avg_cycle = (sum(j.completion_time - j.arrival_time for j in completed) / n
|
| 694 |
+
if n else 0.0)
|
| 695 |
+
throughput = (n / max(now, 0.001)) * 60.0
|
| 696 |
+
|
| 697 |
+
active_jobs: List[Dict[str, Any]] = []
|
| 698 |
+
for zone_id, queue in self.zone_queues.items():
|
| 699 |
+
for job in queue:
|
| 700 |
+
active_jobs.append({
|
| 701 |
+
"id": job.job_id, "type": job.job_type,
|
| 702 |
+
"zoneId": zone_id, "status": "waiting",
|
| 703 |
+
"priority": job.priority,
|
| 704 |
+
})
|
| 705 |
+
|
| 706 |
+
for job in self.all_jobs.values():
|
| 707 |
+
if job.status == "processing" and job.current_op_idx < len(job.operations):
|
| 708 |
+
active_jobs.append({
|
| 709 |
+
"id": job.job_id, "type": job.job_type,
|
| 710 |
+
"zoneId": job.operations[job.current_op_idx].zone_id,
|
| 711 |
+
"status": "processing",
|
| 712 |
+
"priority": job.priority,
|
| 713 |
+
})
|
| 714 |
+
|
| 715 |
+
active_jobs = active_jobs[:50]
|
| 716 |
+
|
| 717 |
+
zone_active = [
|
| 718 |
+
sum(1 for j in self.all_jobs.values()
|
| 719 |
+
if j.status == "processing"
|
| 720 |
+
and j.current_op_idx < len(j.operations)
|
| 721 |
+
and j.operations[j.current_op_idx].zone_id == z)
|
| 722 |
+
for z in range(8)
|
| 723 |
+
]
|
| 724 |
+
|
| 725 |
+
return {
|
| 726 |
+
"time": round(now, 2),
|
| 727 |
+
"activeJobs": active_jobs,
|
| 728 |
+
"zoneQueueLengths": [len(self.zone_queues.get(z, [])) for z in range(8)],
|
| 729 |
+
"zoneActiveCounts": zone_active,
|
| 730 |
+
"metrics": {
|
| 731 |
+
"completed": n,
|
| 732 |
+
"completedJobs": n,
|
| 733 |
+
"totalTardiness": round(total_tard, 1),
|
| 734 |
+
"slaBreachRate": round(sla, 4),
|
| 735 |
+
"avgCycleTime": round(avg_cycle, 2),
|
| 736 |
+
"throughput": round(throughput, 2),
|
| 737 |
+
"jobsPerHour": round(throughput, 2),
|
| 738 |
+
},
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
# ------------------------------------------------------------------
|
| 742 |
+
# Run (batch mode)
|
| 743 |
+
# ------------------------------------------------------------------
|
| 744 |
+
|
| 745 |
+
def run(self, duration: float = 600.0) -> SimulationMetrics:
|
| 746 |
+
"""Execute a full shift simulation and return performance metrics."""
|
| 747 |
+
if not hasattr(self, "_processes_registered") or not self._processes_registered:
|
| 748 |
+
self.init()
|
| 749 |
+
|
| 750 |
+
self.env.run(until=duration)
|
| 751 |
+
|
| 752 |
+
return self._compute_metrics(duration)
|
| 753 |
+
|
| 754 |
+
def _compute_metrics(self, duration: float) -> SimulationMetrics:
|
| 755 |
+
"""Calculate all 7 performance metrics from the completed simulation."""
|
| 756 |
+
completed = self.completed_jobs
|
| 757 |
+
total_jobs = len(self.all_jobs)
|
| 758 |
+
n_completed = len(completed)
|
| 759 |
+
|
| 760 |
+
if not completed:
|
| 761 |
+
return SimulationMetrics(
|
| 762 |
+
makespan=duration,
|
| 763 |
+
zone_utilization={z: 0.0 for z in self.zones},
|
| 764 |
+
queue_history=self._queue_snapshots,
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
makespan = max((j.completion_time for j in completed), default=duration)
|
| 768 |
+
|
| 769 |
+
total_tardiness = sum(
|
| 770 |
+
max(0.0, j.completion_time - j.due_date) for j in completed
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
n_late = sum(1 for j in completed if j.completion_time > j.due_date)
|
| 774 |
+
sla_breach_rate = n_late / n_completed if n_completed else 0.0
|
| 775 |
+
|
| 776 |
+
avg_cycle_time = float(np.mean(
|
| 777 |
+
[j.completion_time - j.arrival_time for j in completed]
|
| 778 |
+
)) if completed else 0.0
|
| 779 |
+
|
| 780 |
+
zone_utilization = {}
|
| 781 |
+
for zone_id, zone in self.zones.items():
|
| 782 |
+
busy = self._zone_busy_time.get(zone_id, 0.0)
|
| 783 |
+
capacity = zone.num_stations * duration
|
| 784 |
+
zone_utilization[zone_id] = min(1.0, busy / capacity) if capacity > 0 else 0.0
|
| 785 |
+
|
| 786 |
+
throughput = (n_completed / duration) * 60.0
|
| 787 |
+
|
| 788 |
+
queue_max = self._max_queue
|
| 789 |
+
|
| 790 |
+
return SimulationMetrics(
|
| 791 |
+
makespan=makespan,
|
| 792 |
+
total_tardiness=total_tardiness,
|
| 793 |
+
sla_breach_rate=sla_breach_rate,
|
| 794 |
+
avg_cycle_time=avg_cycle_time,
|
| 795 |
+
zone_utilization=zone_utilization,
|
| 796 |
+
throughput=throughput,
|
| 797 |
+
queue_max=queue_max,
|
| 798 |
+
queue_history=self._queue_snapshots,
|
| 799 |
+
completed_jobs=n_completed,
|
| 800 |
+
total_jobs=total_jobs,
|
| 801 |
+
)
|
| 802 |
+
|
| 803 |
+
def get_state_snapshot(self) -> Dict[str, Any]:
|
| 804 |
+
"""Return current system state for feature extraction."""
|
| 805 |
+
now = self.env.now
|
| 806 |
+
n_broken = sum(1 for st in self.stations.values() if st.is_broken)
|
| 807 |
+
queue_sizes = {z: len(q) for z, q in self.zone_queues.items()}
|
| 808 |
+
waiting_jobs = [j for j in self.all_jobs.values() if j.status == "waiting"]
|
| 809 |
+
|
| 810 |
+
return {
|
| 811 |
+
"current_time": now,
|
| 812 |
+
"n_orders_in_system": len(waiting_jobs) + sum(
|
| 813 |
+
1 for j in self.all_jobs.values() if j.status == "processing"
|
| 814 |
+
),
|
| 815 |
+
"n_express_orders": sum(1 for j in waiting_jobs if j.job_type == "E"),
|
| 816 |
+
"queue_sizes": queue_sizes,
|
| 817 |
+
"zone_utilization": {
|
| 818 |
+
z: min(1.0, self._zone_busy_time.get(z, 0.0) / max(1.0, now * self.zones[z].num_stations))
|
| 819 |
+
for z in self.zones
|
| 820 |
+
},
|
| 821 |
+
"n_broken_stations": n_broken,
|
| 822 |
+
"lunch_active": self._lunch_active,
|
| 823 |
+
"surge_multiplier": self._surge_base_rate(now),
|
| 824 |
+
"completed_so_far": len(self.completed_jobs),
|
| 825 |
+
"waiting_jobs": waiting_jobs,
|
| 826 |
+
"completed_jobs": self.completed_jobs,
|
| 827 |
+
"all_jobs": self.all_jobs,
|
| 828 |
+
"zones": self.zones,
|
| 829 |
+
"stations": self.stations,
|
| 830 |
+
}
|
| 831 |
+
|
| 832 |
+
# ------------------------------------------------------------------
|
| 833 |
+
# NEW in DAHS_2: State save/restore for snapshot-fork training
|
| 834 |
+
# ------------------------------------------------------------------
|
| 835 |
+
|
| 836 |
+
@staticmethod
|
| 837 |
+
def _serialize_job(job: Job) -> Dict[str, Any]:
|
| 838 |
+
"""Convert a Job to a plain dict (avoids deepcopy of SimPy events)."""
|
| 839 |
+
return {
|
| 840 |
+
"job_id": job.job_id,
|
| 841 |
+
"job_type": job.job_type,
|
| 842 |
+
"arrival_time": job.arrival_time,
|
| 843 |
+
"due_date": job.due_date,
|
| 844 |
+
"operations": [
|
| 845 |
+
{
|
| 846 |
+
"zone_id": op.zone_id,
|
| 847 |
+
"nominal_proc_time": op.nominal_proc_time,
|
| 848 |
+
"actual_proc_time": op.actual_proc_time,
|
| 849 |
+
"start_time": op.start_time,
|
| 850 |
+
"end_time": op.end_time,
|
| 851 |
+
"station_id": op.station_id,
|
| 852 |
+
}
|
| 853 |
+
for op in job.operations
|
| 854 |
+
],
|
| 855 |
+
"current_op_idx": job.current_op_idx,
|
| 856 |
+
"priority": job.priority,
|
| 857 |
+
"status": job.status,
|
| 858 |
+
"completion_time": job.completion_time,
|
| 859 |
+
"priority_escalated": job.priority_escalated,
|
| 860 |
+
}
|
| 861 |
+
|
| 862 |
+
@staticmethod
|
| 863 |
+
def _deserialize_job(d: Dict[str, Any]) -> Job:
|
| 864 |
+
"""Reconstruct a Job from a plain dict."""
|
| 865 |
+
ops = [
|
| 866 |
+
Operation(
|
| 867 |
+
zone_id=o["zone_id"],
|
| 868 |
+
nominal_proc_time=o["nominal_proc_time"],
|
| 869 |
+
actual_proc_time=o["actual_proc_time"],
|
| 870 |
+
start_time=o["start_time"],
|
| 871 |
+
end_time=o["end_time"],
|
| 872 |
+
station_id=o["station_id"],
|
| 873 |
+
)
|
| 874 |
+
for o in d["operations"]
|
| 875 |
+
]
|
| 876 |
+
job = Job(
|
| 877 |
+
job_id=d["job_id"],
|
| 878 |
+
job_type=d["job_type"],
|
| 879 |
+
arrival_time=d["arrival_time"],
|
| 880 |
+
due_date=d["due_date"],
|
| 881 |
+
operations=ops,
|
| 882 |
+
current_op_idx=d["current_op_idx"],
|
| 883 |
+
priority=d["priority"],
|
| 884 |
+
status=d["status"],
|
| 885 |
+
completion_time=d["completion_time"],
|
| 886 |
+
priority_escalated=d["priority_escalated"],
|
| 887 |
+
)
|
| 888 |
+
return job
|
| 889 |
+
|
| 890 |
+
def save_state(self) -> Dict[str, Any]:
|
| 891 |
+
"""Capture complete simulation state for snapshot-fork training.
|
| 892 |
+
|
| 893 |
+
Returns a pickling-safe dict (no SimPy objects) containing:
|
| 894 |
+
- env.now (current time)
|
| 895 |
+
- Serialized jobs, completed_jobs, zone_queues (as job IDs)
|
| 896 |
+
- All station states (is_broken, repair_end_time, current_job, busy_until)
|
| 897 |
+
- RNG state via rng.bit_generator.state
|
| 898 |
+
- _job_counter, _zone_busy_time, _lunch_active, queue snapshot history
|
| 899 |
+
|
| 900 |
+
NOTE: The from_state() classmethod creates a fresh SimPy environment and
|
| 901 |
+
re-initializes processes from the saved data point.
|
| 902 |
+
"""
|
| 903 |
+
state = {
|
| 904 |
+
"env_time": self.env.now,
|
| 905 |
+
"seed": self.seed,
|
| 906 |
+
"_job_counter": self._job_counter,
|
| 907 |
+
"_max_queue": self._max_queue,
|
| 908 |
+
"_lunch_active": self._lunch_active,
|
| 909 |
+
"_zone_busy_time": dict(self._zone_busy_time),
|
| 910 |
+
"_queue_snapshots": list(self._queue_snapshots),
|
| 911 |
+
"rng_state": self.rng.bit_generator.state,
|
| 912 |
+
# Simulator config for reconstruction
|
| 913 |
+
"_base_arrival_rate": self._base_arrival_rate,
|
| 914 |
+
"_breakdown_prob": self._breakdown_prob,
|
| 915 |
+
"_batch_arrival_size": self._batch_arrival_size,
|
| 916 |
+
"_lunch_penalty_factor": self._lunch_penalty_factor,
|
| 917 |
+
"_job_type_frequencies": dict(self._job_type_frequencies),
|
| 918 |
+
"_due_date_tightness": self._due_date_tightness,
|
| 919 |
+
"_processing_time_scale": self._processing_time_scale,
|
| 920 |
+
# Serialized job data (can't deepcopy — SimPy events aren't picklable)
|
| 921 |
+
"all_jobs": {
|
| 922 |
+
jid: self._serialize_job(job)
|
| 923 |
+
for jid, job in self.all_jobs.items()
|
| 924 |
+
},
|
| 925 |
+
"completed_jobs": [self._serialize_job(j) for j in self.completed_jobs],
|
| 926 |
+
"zone_queues": {z: [j.job_id for j in q] for z, q in self.zone_queues.items()},
|
| 927 |
+
# Station states
|
| 928 |
+
"stations": {
|
| 929 |
+
sid: {
|
| 930 |
+
"station_id": st.station_id,
|
| 931 |
+
"zone_id": st.zone_id,
|
| 932 |
+
"is_broken": st.is_broken,
|
| 933 |
+
"repair_end_time": st.repair_end_time,
|
| 934 |
+
"current_job": st.current_job,
|
| 935 |
+
"busy_until": st.busy_until,
|
| 936 |
+
}
|
| 937 |
+
for sid, st in self.stations.items()
|
| 938 |
+
},
|
| 939 |
+
}
|
| 940 |
+
return state
|
| 941 |
+
|
| 942 |
+
@classmethod
|
| 943 |
+
def from_state(
|
| 944 |
+
cls,
|
| 945 |
+
state_dict: Dict[str, Any],
|
| 946 |
+
heuristic_fn: Callable,
|
| 947 |
+
) -> "WarehouseSimulator":
|
| 948 |
+
"""Create a new simulator from a saved state (for fork evaluation).
|
| 949 |
+
|
| 950 |
+
Creates a fresh SimPy environment initialized at saved_time,
|
| 951 |
+
restores all job/station/queue data, and continues RNG from saved state.
|
| 952 |
+
|
| 953 |
+
Parameters
|
| 954 |
+
----------
|
| 955 |
+
state_dict : dict
|
| 956 |
+
Output of save_state().
|
| 957 |
+
heuristic_fn : Callable
|
| 958 |
+
Dispatch function to use in the forked simulation.
|
| 959 |
+
|
| 960 |
+
Returns
|
| 961 |
+
-------
|
| 962 |
+
WarehouseSimulator
|
| 963 |
+
Ready to run from state_dict["env_time"] forward.
|
| 964 |
+
"""
|
| 965 |
+
saved_time = state_dict["env_time"]
|
| 966 |
+
|
| 967 |
+
# Reconstruct simulator with original config
|
| 968 |
+
sim = cls(
|
| 969 |
+
seed=state_dict["seed"],
|
| 970 |
+
heuristic_fn=heuristic_fn,
|
| 971 |
+
base_arrival_rate=state_dict["_base_arrival_rate"],
|
| 972 |
+
breakdown_prob=state_dict["_breakdown_prob"],
|
| 973 |
+
batch_arrival_size=state_dict["_batch_arrival_size"],
|
| 974 |
+
lunch_penalty_factor=state_dict["_lunch_penalty_factor"],
|
| 975 |
+
job_type_frequencies=state_dict["_job_type_frequencies"],
|
| 976 |
+
due_date_tightness=state_dict["_due_date_tightness"],
|
| 977 |
+
processing_time_scale=state_dict["_processing_time_scale"],
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
+
# Restore RNG from saved state (deterministic continuation)
|
| 981 |
+
sim.rng.bit_generator.state = state_dict["rng_state"]
|
| 982 |
+
|
| 983 |
+
# Restore job counter and metrics
|
| 984 |
+
sim._job_counter = state_dict["_job_counter"]
|
| 985 |
+
sim._max_queue = state_dict["_max_queue"]
|
| 986 |
+
sim._lunch_active = state_dict["_lunch_active"]
|
| 987 |
+
sim._zone_busy_time = dict(state_dict["_zone_busy_time"])
|
| 988 |
+
sim._queue_snapshots = list(state_dict["_queue_snapshots"])
|
| 989 |
+
|
| 990 |
+
# Restore jobs from serialized dicts
|
| 991 |
+
sim.all_jobs = {
|
| 992 |
+
jid: cls._deserialize_job(jdata)
|
| 993 |
+
for jid, jdata in state_dict["all_jobs"].items()
|
| 994 |
+
}
|
| 995 |
+
sim.completed_jobs = [
|
| 996 |
+
cls._deserialize_job(jdata)
|
| 997 |
+
for jdata in state_dict["completed_jobs"]
|
| 998 |
+
]
|
| 999 |
+
|
| 1000 |
+
# Restore zone queues (using saved job IDs to reference restored jobs)
|
| 1001 |
+
job_by_id = sim.all_jobs
|
| 1002 |
+
for z, queue_job_ids in state_dict["zone_queues"].items():
|
| 1003 |
+
sim.zone_queues[int(z)] = [
|
| 1004 |
+
job_by_id[jid] for jid in queue_job_ids
|
| 1005 |
+
if jid in job_by_id
|
| 1006 |
+
]
|
| 1007 |
+
|
| 1008 |
+
# Restore station states
|
| 1009 |
+
for sid_str, st_data in state_dict["stations"].items():
|
| 1010 |
+
sid = int(sid_str)
|
| 1011 |
+
if sid in sim.stations:
|
| 1012 |
+
sim.stations[sid].is_broken = st_data["is_broken"]
|
| 1013 |
+
sim.stations[sid].repair_end_time = st_data["repair_end_time"]
|
| 1014 |
+
sim.stations[sid].current_job = st_data["current_job"]
|
| 1015 |
+
sim.stations[sid].busy_until = st_data["busy_until"]
|
| 1016 |
+
|
| 1017 |
+
# Create a SimPy environment starting at saved_time
|
| 1018 |
+
sim.env = simpy.Environment(initial_time=saved_time)
|
| 1019 |
+
|
| 1020 |
+
# Re-create SimPy resources for the new environment
|
| 1021 |
+
for sid in sim.stations:
|
| 1022 |
+
sim.station_resources[sid] = simpy.Resource(sim.env, capacity=1)
|
| 1023 |
+
|
| 1024 |
+
# Re-create dispatcher trigger events for new environment
|
| 1025 |
+
for zone_id in sim.zones:
|
| 1026 |
+
sim.dispatcher_triggers[zone_id] = sim.env.event()
|
| 1027 |
+
|
| 1028 |
+
# Re-register dispatchers and breakdown/arrival processes
|
| 1029 |
+
sim.env.process(sim._arrival_process())
|
| 1030 |
+
sim.env.process(sim._batch_arrival_process())
|
| 1031 |
+
sim.env.process(sim._priority_escalation_process())
|
| 1032 |
+
|
| 1033 |
+
# Re-register lunch process correctly based on saved time
|
| 1034 |
+
if saved_time < 300.0:
|
| 1035 |
+
sim.env.process(sim._lunch_break_process())
|
| 1036 |
+
elif saved_time < 360.0:
|
| 1037 |
+
# Currently in lunch — restore the remaining lunch period
|
| 1038 |
+
remaining_lunch = 360.0 - saved_time
|
| 1039 |
+
|
| 1040 |
+
def _remaining_lunch():
|
| 1041 |
+
yield sim.env.timeout(remaining_lunch)
|
| 1042 |
+
sim._lunch_active = False
|
| 1043 |
+
|
| 1044 |
+
sim.env.process(_remaining_lunch())
|
| 1045 |
+
|
| 1046 |
+
sim.env.process(sim._snapshot_process())
|
| 1047 |
+
|
| 1048 |
+
for zone_id in sim.zones:
|
| 1049 |
+
sim.env.process(sim._zone_dispatcher(zone_id))
|
| 1050 |
+
|
| 1051 |
+
for station in sim.stations.values():
|
| 1052 |
+
if station.is_broken:
|
| 1053 |
+
remaining_repair = max(0.1, station.repair_end_time - saved_time)
|
| 1054 |
+
|
| 1055 |
+
def _resume_repair(st=station, t=remaining_repair):
|
| 1056 |
+
yield sim.env.timeout(t)
|
| 1057 |
+
st.is_broken = False
|
| 1058 |
+
sim._trigger_dispatcher(st.zone_id)
|
| 1059 |
+
# Continue with future breakdowns
|
| 1060 |
+
while True:
|
| 1061 |
+
ttf = float(sim.rng.exponential(1.0 / max(sim._breakdown_prob, 1e-9)))
|
| 1062 |
+
yield sim.env.timeout(ttf)
|
| 1063 |
+
st.is_broken = True
|
| 1064 |
+
repair_time = float(sim.rng.exponential(18.0))
|
| 1065 |
+
st.repair_end_time = sim.env.now + repair_time
|
| 1066 |
+
yield sim.env.timeout(repair_time)
|
| 1067 |
+
st.is_broken = False
|
| 1068 |
+
sim._trigger_dispatcher(st.zone_id)
|
| 1069 |
+
|
| 1070 |
+
sim.env.process(_resume_repair())
|
| 1071 |
+
else:
|
| 1072 |
+
sim.env.process(sim._station_breakdown_process(station))
|
| 1073 |
+
|
| 1074 |
+
# Resume WAITING jobs in zone queues:
|
| 1075 |
+
# These need a full _process_job-like coroutine that waits for dispatch
|
| 1076 |
+
# then routes through remaining operations.
|
| 1077 |
+
for zone_id, queue in sim.zone_queues.items():
|
| 1078 |
+
for job in queue:
|
| 1079 |
+
job._dispatch_event = sim.env.event()
|
| 1080 |
+
sim.env.process(sim._resume_waiting_job(job, zone_id))
|
| 1081 |
+
if queue:
|
| 1082 |
+
sim._trigger_dispatcher(zone_id)
|
| 1083 |
+
|
| 1084 |
+
# Resume PROCESSING jobs with correct remaining time:
|
| 1085 |
+
# At save time, op.start_time and op.actual_proc_time are set,
|
| 1086 |
+
# but op.end_time is still -1.0 (only set after timeout completes).
|
| 1087 |
+
# Remaining = (start_time + actual_proc_time) - saved_time
|
| 1088 |
+
for job in sim.all_jobs.values():
|
| 1089 |
+
if job.status == "processing" and job.current_op_idx < len(job.operations):
|
| 1090 |
+
op = job.operations[job.current_op_idx]
|
| 1091 |
+
if op.start_time >= 0 and op.actual_proc_time > 0:
|
| 1092 |
+
expected_end = op.start_time + op.actual_proc_time
|
| 1093 |
+
remaining = max(0.0, expected_end - saved_time)
|
| 1094 |
+
else:
|
| 1095 |
+
remaining = 0.0
|
| 1096 |
+
sim.env.process(sim._resume_job(job, remaining))
|
| 1097 |
+
|
| 1098 |
+
return sim
|
| 1099 |
+
|
| 1100 |
+
def _resume_job(self, job: Job, remaining_time: float):
|
| 1101 |
+
"""Continue processing a job that was in-progress at save_state time."""
|
| 1102 |
+
op_idx = job.current_op_idx
|
| 1103 |
+
op = job.operations[op_idx]
|
| 1104 |
+
|
| 1105 |
+
yield self.env.timeout(remaining_time)
|
| 1106 |
+
op.end_time = self.env.now
|
| 1107 |
+
|
| 1108 |
+
# Continue with remaining operations
|
| 1109 |
+
for next_op_idx in range(op_idx + 1, len(job.operations)):
|
| 1110 |
+
next_op = job.operations[next_op_idx]
|
| 1111 |
+
zone_id = next_op.zone_id
|
| 1112 |
+
|
| 1113 |
+
self.zone_queues[zone_id].append(job)
|
| 1114 |
+
job.status = "waiting"
|
| 1115 |
+
job._dispatch_event = self.env.event()
|
| 1116 |
+
self._trigger_dispatcher(zone_id)
|
| 1117 |
+
yield job._dispatch_event
|
| 1118 |
+
|
| 1119 |
+
station_id = self._pick_station(zone_id)
|
| 1120 |
+
next_op.station_id = station_id
|
| 1121 |
+
resource = self.station_resources[station_id]
|
| 1122 |
+
st = self.stations[station_id]
|
| 1123 |
+
st.current_job = job.job_id
|
| 1124 |
+
|
| 1125 |
+
with resource.request() as req:
|
| 1126 |
+
yield req
|
| 1127 |
+
while st.is_broken:
|
| 1128 |
+
wait_time = max(0.1, st.repair_end_time - self.env.now)
|
| 1129 |
+
yield self.env.timeout(wait_time)
|
| 1130 |
+
|
| 1131 |
+
job.status = "processing"
|
| 1132 |
+
job.current_op_idx = next_op_idx
|
| 1133 |
+
|
| 1134 |
+
variability = float(self.rng.lognormal(0, 0.30))
|
| 1135 |
+
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
|
| 1136 |
+
actual_time = next_op.nominal_proc_time * variability * lunch_penalty
|
| 1137 |
+
|
| 1138 |
+
next_op.actual_proc_time = actual_time
|
| 1139 |
+
next_op.start_time = self.env.now
|
| 1140 |
+
self._zone_busy_time[zone_id] = self._zone_busy_time.get(zone_id, 0.0) + actual_time
|
| 1141 |
+
|
| 1142 |
+
yield self.env.timeout(actual_time)
|
| 1143 |
+
|
| 1144 |
+
next_op.end_time = self.env.now
|
| 1145 |
+
st.busy_until = self.env.now
|
| 1146 |
+
st.current_job = None
|
| 1147 |
+
|
| 1148 |
+
self._trigger_dispatcher(zone_id)
|
| 1149 |
+
|
| 1150 |
+
job.status = "done"
|
| 1151 |
+
job.completion_time = self.env.now
|
| 1152 |
+
job.current_op_idx = len(job.operations)
|
| 1153 |
+
self.completed_jobs.append(job)
|
| 1154 |
+
|
| 1155 |
+
def _resume_waiting_job(self, job: Job, current_zone_id: int):
|
| 1156 |
+
"""Resume a job that was waiting in a zone queue at save_state time.
|
| 1157 |
+
|
| 1158 |
+
This replaces the missing _process_job coroutine for waiting jobs
|
| 1159 |
+
restored via from_state(). The job waits for dispatch in its current
|
| 1160 |
+
zone, processes that operation, then routes through all remaining ops.
|
| 1161 |
+
"""
|
| 1162 |
+
# Wait for dispatcher to select this job in the current zone
|
| 1163 |
+
yield job._dispatch_event
|
| 1164 |
+
|
| 1165 |
+
# Process the current operation (the one the job was waiting for)
|
| 1166 |
+
op_idx = job.current_op_idx
|
| 1167 |
+
op = job.operations[op_idx]
|
| 1168 |
+
zone_id = current_zone_id
|
| 1169 |
+
|
| 1170 |
+
station_id = self._pick_station(zone_id)
|
| 1171 |
+
op.station_id = station_id
|
| 1172 |
+
resource = self.station_resources[station_id]
|
| 1173 |
+
st = self.stations[station_id]
|
| 1174 |
+
st.current_job = job.job_id
|
| 1175 |
+
|
| 1176 |
+
with resource.request() as req:
|
| 1177 |
+
yield req
|
| 1178 |
+
while st.is_broken:
|
| 1179 |
+
wait_time = max(0.1, st.repair_end_time - self.env.now)
|
| 1180 |
+
yield self.env.timeout(wait_time)
|
| 1181 |
+
|
| 1182 |
+
job.status = "processing"
|
| 1183 |
+
job.current_op_idx = op_idx
|
| 1184 |
+
|
| 1185 |
+
variability = float(self.rng.lognormal(0, 0.30))
|
| 1186 |
+
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
|
| 1187 |
+
actual_time = op.nominal_proc_time * variability * lunch_penalty
|
| 1188 |
+
|
| 1189 |
+
op.actual_proc_time = actual_time
|
| 1190 |
+
op.start_time = self.env.now
|
| 1191 |
+
self._zone_busy_time[zone_id] = self._zone_busy_time.get(zone_id, 0.0) + actual_time
|
| 1192 |
+
|
| 1193 |
+
yield self.env.timeout(actual_time)
|
| 1194 |
+
|
| 1195 |
+
op.end_time = self.env.now
|
| 1196 |
+
st.busy_until = self.env.now
|
| 1197 |
+
st.current_job = None
|
| 1198 |
+
|
| 1199 |
+
self._trigger_dispatcher(zone_id)
|
| 1200 |
+
|
| 1201 |
+
# Continue with remaining operations (same as _resume_job)
|
| 1202 |
+
for next_op_idx in range(op_idx + 1, len(job.operations)):
|
| 1203 |
+
next_op = job.operations[next_op_idx]
|
| 1204 |
+
next_zone_id = next_op.zone_id
|
| 1205 |
+
|
| 1206 |
+
self.zone_queues[next_zone_id].append(job)
|
| 1207 |
+
job.status = "waiting"
|
| 1208 |
+
job._dispatch_event = self.env.event()
|
| 1209 |
+
self._trigger_dispatcher(next_zone_id)
|
| 1210 |
+
yield job._dispatch_event
|
| 1211 |
+
|
| 1212 |
+
station_id = self._pick_station(next_zone_id)
|
| 1213 |
+
next_op.station_id = station_id
|
| 1214 |
+
resource = self.station_resources[station_id]
|
| 1215 |
+
st = self.stations[station_id]
|
| 1216 |
+
st.current_job = job.job_id
|
| 1217 |
+
|
| 1218 |
+
with resource.request() as req:
|
| 1219 |
+
yield req
|
| 1220 |
+
while st.is_broken:
|
| 1221 |
+
wait_time = max(0.1, st.repair_end_time - self.env.now)
|
| 1222 |
+
yield self.env.timeout(wait_time)
|
| 1223 |
+
|
| 1224 |
+
job.status = "processing"
|
| 1225 |
+
job.current_op_idx = next_op_idx
|
| 1226 |
+
|
| 1227 |
+
variability = float(self.rng.lognormal(0, 0.30))
|
| 1228 |
+
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
|
| 1229 |
+
actual_time = next_op.nominal_proc_time * variability * lunch_penalty
|
| 1230 |
+
|
| 1231 |
+
next_op.actual_proc_time = actual_time
|
| 1232 |
+
next_op.start_time = self.env.now
|
| 1233 |
+
self._zone_busy_time[next_zone_id] = self._zone_busy_time.get(next_zone_id, 0.0) + actual_time
|
| 1234 |
+
|
| 1235 |
+
yield self.env.timeout(actual_time)
|
| 1236 |
+
|
| 1237 |
+
next_op.end_time = self.env.now
|
| 1238 |
+
st.busy_until = self.env.now
|
| 1239 |
+
st.current_job = None
|
| 1240 |
+
|
| 1241 |
+
self._trigger_dispatcher(next_zone_id)
|
| 1242 |
+
|
| 1243 |
+
job.status = "done"
|
| 1244 |
+
job.completion_time = self.env.now
|
| 1245 |
+
job.current_op_idx = len(job.operations)
|
| 1246 |
+
self.completed_jobs.append(job)
|
| 1247 |
+
|
| 1248 |
+
# ------------------------------------------------------------------
|
| 1249 |
+
# NEW in DAHS_2: Partial metrics for fork evaluation windows
|
| 1250 |
+
# ------------------------------------------------------------------
|
| 1251 |
+
|
| 1252 |
+
def get_partial_metrics(self, since_time: float) -> SimulationMetrics:
|
| 1253 |
+
"""Compute metrics only for jobs completed between since_time and env.now.
|
| 1254 |
+
|
| 1255 |
+
Used in the 20-minute fork evaluation window during data generation.
|
| 1256 |
+
|
| 1257 |
+
Parameters
|
| 1258 |
+
----------
|
| 1259 |
+
since_time : float
|
| 1260 |
+
Start of evaluation window (simulation time).
|
| 1261 |
+
|
| 1262 |
+
Returns
|
| 1263 |
+
-------
|
| 1264 |
+
SimulationMetrics
|
| 1265 |
+
Metrics computed only over jobs completed in [since_time, now].
|
| 1266 |
+
"""
|
| 1267 |
+
now = self.env.now
|
| 1268 |
+
window_jobs = [
|
| 1269 |
+
j for j in self.completed_jobs
|
| 1270 |
+
if j.completion_time >= since_time
|
| 1271 |
+
]
|
| 1272 |
+
|
| 1273 |
+
if not window_jobs:
|
| 1274 |
+
return SimulationMetrics(
|
| 1275 |
+
makespan=now,
|
| 1276 |
+
zone_utilization={z: 0.0 for z in self.zones},
|
| 1277 |
+
)
|
| 1278 |
+
|
| 1279 |
+
n = len(window_jobs)
|
| 1280 |
+
total_tardiness = sum(max(0.0, j.completion_time - j.due_date) for j in window_jobs)
|
| 1281 |
+
n_late = sum(1 for j in window_jobs if j.completion_time > j.due_date)
|
| 1282 |
+
sla_breach_rate = n_late / n
|
| 1283 |
+
avg_cycle_time = float(np.mean([j.completion_time - j.arrival_time for j in window_jobs]))
|
| 1284 |
+
duration = max(now - since_time, 1.0)
|
| 1285 |
+
throughput = (n / duration) * 60.0
|
| 1286 |
+
|
| 1287 |
+
zone_utilization = {
|
| 1288 |
+
z: min(1.0, self._zone_busy_time.get(z, 0.0) / max(1.0, now * self.zones[z].num_stations))
|
| 1289 |
+
for z in self.zones
|
| 1290 |
+
}
|
| 1291 |
+
|
| 1292 |
+
return SimulationMetrics(
|
| 1293 |
+
makespan=max(j.completion_time for j in window_jobs),
|
| 1294 |
+
total_tardiness=total_tardiness,
|
| 1295 |
+
sla_breach_rate=sla_breach_rate,
|
| 1296 |
+
avg_cycle_time=avg_cycle_time,
|
| 1297 |
+
zone_utilization=zone_utilization,
|
| 1298 |
+
throughput=throughput,
|
| 1299 |
+
queue_max=self._max_queue,
|
| 1300 |
+
completed_jobs=n,
|
| 1301 |
+
total_jobs=len(self.all_jobs),
|
| 1302 |
+
)
|
src/train_priority.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_priority.py — Train GBR Priority Predictor (port from DAHS_1)
|
| 3 |
+
|
| 4 |
+
Trains a GradientBoostingRegressor on the priority dataset to predict
|
| 5 |
+
a continuous job priority score used by the Hybrid-Priority scheduler.
|
| 6 |
+
|
| 7 |
+
Outputs:
|
| 8 |
+
- models/priority_gbr.joblib
|
| 9 |
+
- results/plots/shap_summary.png
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import warnings
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import joblib
|
| 19 |
+
import matplotlib
|
| 20 |
+
matplotlib.use("Agg")
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
import numpy as np
|
| 23 |
+
import pandas as pd
|
| 24 |
+
import shap
|
| 25 |
+
from sklearn.ensemble import GradientBoostingRegressor
|
| 26 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
| 27 |
+
from sklearn.model_selection import KFold, cross_val_score, train_test_split
|
| 28 |
+
|
| 29 |
+
warnings.filterwarnings("ignore")
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
DATA_PATH = Path(__file__).parent.parent / "data" / "raw" / "priority_dataset.csv"
|
| 34 |
+
MODELS_DIR = Path(__file__).parent.parent / "models"
|
| 35 |
+
PLOTS_DIR = Path(__file__).parent.parent / "results" / "plots"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def train_priority_model(data_path: Path = DATA_PATH) -> GradientBoostingRegressor:
|
| 39 |
+
"""Train and evaluate the GBR priority predictor.
|
| 40 |
+
|
| 41 |
+
Returns
|
| 42 |
+
-------
|
| 43 |
+
GradientBoostingRegressor
|
| 44 |
+
Fitted model.
|
| 45 |
+
"""
|
| 46 |
+
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
| 47 |
+
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
logger.info("Loading priority dataset from %s", data_path)
|
| 50 |
+
df = pd.read_csv(data_path)
|
| 51 |
+
# Bug fix from DAHS_1: use replace + dropna (not nan_to_num alone)
|
| 52 |
+
df = df.replace([np.inf, -np.inf], np.nan).dropna()
|
| 53 |
+
|
| 54 |
+
feature_cols = [c for c in df.columns if c != "priority_score"]
|
| 55 |
+
X = df[feature_cols].values.astype(np.float32)
|
| 56 |
+
y = df["priority_score"].values.astype(np.float32)
|
| 57 |
+
|
| 58 |
+
logger.info("Priority dataset shape: X=%s, y=%s", X.shape, y.shape)
|
| 59 |
+
|
| 60 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 61 |
+
X, y, test_size=0.20, random_state=42
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
model = GradientBoostingRegressor(
|
| 65 |
+
n_estimators=300,
|
| 66 |
+
max_depth=6,
|
| 67 |
+
learning_rate=0.05,
|
| 68 |
+
subsample=0.8,
|
| 69 |
+
min_samples_leaf=5,
|
| 70 |
+
random_state=42,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
logger.info("Training GradientBoostingRegressor ...")
|
| 74 |
+
model.fit(X_train, y_train)
|
| 75 |
+
|
| 76 |
+
y_pred = model.predict(X_test)
|
| 77 |
+
r2 = r2_score(y_test, y_pred)
|
| 78 |
+
mae = mean_absolute_error(y_test, y_pred)
|
| 79 |
+
rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
|
| 80 |
+
|
| 81 |
+
print(f"[GBR] Test R^2: {r2:.4f}")
|
| 82 |
+
print(f"[GBR] Test MAE: {mae:.4f}")
|
| 83 |
+
print(f"[GBR] Test RMSE: {rmse:.4f}")
|
| 84 |
+
logger.info("GBR Test -> R^2=%.4f MAE=%.4f RMSE=%.4f", r2, mae, rmse)
|
| 85 |
+
|
| 86 |
+
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
| 87 |
+
cv_scores = cross_val_score(model, X_train, y_train, cv=cv, scoring="r2", n_jobs=-1)
|
| 88 |
+
print(f"[GBR] 5-Fold CV R^2: {cv_scores.mean():.4f} +/- {cv_scores.std():.4f}")
|
| 89 |
+
logger.info("GBR CV R^2: %.4f +/- %.4f", cv_scores.mean(), cv_scores.std())
|
| 90 |
+
|
| 91 |
+
model_path = MODELS_DIR / "priority_gbr.joblib"
|
| 92 |
+
joblib.dump(model, model_path)
|
| 93 |
+
logger.info("Saved model -> %s", model_path)
|
| 94 |
+
|
| 95 |
+
_generate_shap_plot(model, X_test, feature_cols)
|
| 96 |
+
|
| 97 |
+
return model
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _generate_shap_plot(
|
| 101 |
+
model: GradientBoostingRegressor,
|
| 102 |
+
X_sample: np.ndarray,
|
| 103 |
+
feature_names: list,
|
| 104 |
+
) -> None:
|
| 105 |
+
"""Generate and save SHAP beeswarm summary plot."""
|
| 106 |
+
logger.info("Computing SHAP values ...")
|
| 107 |
+
sample_size = min(500, X_sample.shape[0])
|
| 108 |
+
X_shap = X_sample[:sample_size]
|
| 109 |
+
|
| 110 |
+
explainer = shap.TreeExplainer(model)
|
| 111 |
+
shap_values = explainer.shap_values(X_shap)
|
| 112 |
+
|
| 113 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
| 114 |
+
fig.patch.set_facecolor("#0f1117")
|
| 115 |
+
ax.set_facecolor("#1a1d27")
|
| 116 |
+
|
| 117 |
+
shap.summary_plot(
|
| 118 |
+
shap_values,
|
| 119 |
+
X_shap,
|
| 120 |
+
feature_names=feature_names,
|
| 121 |
+
show=False,
|
| 122 |
+
plot_type="dot",
|
| 123 |
+
color_bar=True,
|
| 124 |
+
max_display=18,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
plt.gcf().set_facecolor("#0f1117")
|
| 128 |
+
plt.title("Priority GBR — SHAP Feature Importance", color="white", fontsize=14, pad=12)
|
| 129 |
+
plt.tight_layout()
|
| 130 |
+
|
| 131 |
+
shap_path = PLOTS_DIR / "shap_summary.png"
|
| 132 |
+
plt.savefig(shap_path, dpi=150, bbox_inches="tight", facecolor="#0f1117")
|
| 133 |
+
plt.close()
|
| 134 |
+
logger.info("Saved SHAP plot -> %s", shap_path)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 139 |
+
train_priority_model()
|
src/train_selector.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_selector.py — Train Heuristic Selector Models (DAHS_2)
|
| 3 |
+
|
| 4 |
+
Trains three classifiers (Decision Tree, Random Forest, XGBoost) to predict
|
| 5 |
+
which of 6 heuristics achieves the best dispatching outcome for a given
|
| 6 |
+
system state (snapshot-fork labels).
|
| 7 |
+
|
| 8 |
+
NEW in DAHS_2:
|
| 9 |
+
- Exports models/feature_ranges.json
|
| 10 |
+
- Exports models/dt_structure.json (for frontend glass-box)
|
| 11 |
+
- Exports models/feature_names.json
|
| 12 |
+
|
| 13 |
+
Outputs:
|
| 14 |
+
- models/selector_dt.joblib
|
| 15 |
+
- models/selector_rf.joblib
|
| 16 |
+
- models/selector_xgb.joblib
|
| 17 |
+
- models/feature_ranges.json
|
| 18 |
+
- models/dt_structure.json
|
| 19 |
+
- models/feature_names.json
|
| 20 |
+
- results/plots/feature_importance.png
|
| 21 |
+
- results/plots/decision_tree.png
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import hashlib
|
| 27 |
+
import json
|
| 28 |
+
import logging
|
| 29 |
+
import time
|
| 30 |
+
import warnings
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Any, Dict, List
|
| 33 |
+
|
| 34 |
+
import joblib
|
| 35 |
+
import matplotlib
|
| 36 |
+
matplotlib.use("Agg")
|
| 37 |
+
import matplotlib.pyplot as plt
|
| 38 |
+
import numpy as np
|
| 39 |
+
import pandas as pd
|
| 40 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 41 |
+
from sklearn.metrics import classification_report
|
| 42 |
+
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
|
| 43 |
+
from sklearn.tree import DecisionTreeClassifier, plot_tree
|
| 44 |
+
from xgboost import XGBClassifier
|
| 45 |
+
|
| 46 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 47 |
+
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
DATA_PATH = Path(__file__).parent.parent / "data" / "raw" / "selector_dataset.csv"
|
| 51 |
+
MODELS_DIR = Path(__file__).parent.parent / "models"
|
| 52 |
+
PLOTS_DIR = Path(__file__).parent.parent / "results" / "plots"
|
| 53 |
+
|
| 54 |
+
LABEL_NAMES = ["FIFO", "Priority-EDD", "Critical-Ratio", "ATC", "WSPT", "Slack"]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _extract_dt_structure(dt: DecisionTreeClassifier, feature_names: List[str]) -> Dict[str, Any]:
|
| 58 |
+
"""Extract decision tree node structure for frontend glass-box visualization.
|
| 59 |
+
|
| 60 |
+
Returns a dict with nodes list, each node having:
|
| 61 |
+
{id, feature, threshold, left, right, class, samples, impurity}
|
| 62 |
+
"""
|
| 63 |
+
tree = dt.tree_
|
| 64 |
+
nodes = []
|
| 65 |
+
|
| 66 |
+
def _recurse(node_id: int) -> None:
|
| 67 |
+
feature_idx = int(tree.feature[node_id])
|
| 68 |
+
threshold = float(tree.threshold[node_id])
|
| 69 |
+
left_child = int(tree.children_left[node_id])
|
| 70 |
+
right_child = int(tree.children_right[node_id])
|
| 71 |
+
values = tree.value[node_id][0]
|
| 72 |
+
dominant = int(np.argmax(values))
|
| 73 |
+
samples = int(tree.n_node_samples[node_id])
|
| 74 |
+
impurity = float(tree.impurity[node_id])
|
| 75 |
+
|
| 76 |
+
node: Dict[str, Any] = {
|
| 77 |
+
"id": node_id,
|
| 78 |
+
"samples": samples,
|
| 79 |
+
"impurity": round(impurity, 4),
|
| 80 |
+
"class": LABEL_NAMES[dominant],
|
| 81 |
+
"classIdx": dominant,
|
| 82 |
+
"values": [int(v) for v in values],
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
if left_child != -1: # not a leaf
|
| 86 |
+
feat_name = feature_names[feature_idx] if feature_idx < len(feature_names) else f"f{feature_idx}"
|
| 87 |
+
node["feature"] = feat_name
|
| 88 |
+
node["featureIdx"] = feature_idx
|
| 89 |
+
node["threshold"] = round(threshold, 4)
|
| 90 |
+
node["left"] = left_child
|
| 91 |
+
node["right"] = right_child
|
| 92 |
+
_recurse(left_child)
|
| 93 |
+
_recurse(right_child)
|
| 94 |
+
|
| 95 |
+
nodes.append(node)
|
| 96 |
+
|
| 97 |
+
_recurse(0)
|
| 98 |
+
return {"nodes": nodes, "featureNames": feature_names, "classNames": LABEL_NAMES}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def train_selector_models(data_path: Path = DATA_PATH) -> dict:
|
| 102 |
+
"""Train all three selector classifiers and save artifacts.
|
| 103 |
+
|
| 104 |
+
Returns
|
| 105 |
+
-------
|
| 106 |
+
dict
|
| 107 |
+
Mapping model_name -> trained sklearn-compatible model.
|
| 108 |
+
"""
|
| 109 |
+
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
| 110 |
+
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 111 |
+
|
| 112 |
+
logger.info("Loading selector dataset from %s", data_path)
|
| 113 |
+
df = pd.read_csv(data_path)
|
| 114 |
+
|
| 115 |
+
feature_cols = [c for c in df.columns if c != "label"]
|
| 116 |
+
X = df[feature_cols].values.astype(np.float32)
|
| 117 |
+
# Sanitize: NaN/inf safety (training pipeline bug fix from DAHS_1)
|
| 118 |
+
X = np.nan_to_num(X, nan=0.0, posinf=999.0, neginf=-999.0)
|
| 119 |
+
y = df["label"].values.astype(int)
|
| 120 |
+
|
| 121 |
+
logger.info("Dataset shape: X=%s, label distribution: %s",
|
| 122 |
+
X.shape, dict(zip(*np.unique(y, return_counts=True))))
|
| 123 |
+
|
| 124 |
+
# Training-run hash binds every artifact in this run together so the
|
| 125 |
+
# selector loader can detect a stale OOD ranges file or a feature-list
|
| 126 |
+
# mismatch loudly rather than silently shifting baseline-vs-DAHS results.
|
| 127 |
+
run_hash = hashlib.sha256(
|
| 128 |
+
f"{time.time()}|{X.shape}|{','.join(feature_cols)}|{int(y.sum())}".encode()
|
| 129 |
+
).hexdigest()[:16]
|
| 130 |
+
|
| 131 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 132 |
+
X, y, test_size=0.20, random_state=42, stratify=y
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# CV seed different from train/test split seed (bug fix)
|
| 136 |
+
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
|
| 137 |
+
|
| 138 |
+
from sklearn.utils.class_weight import compute_sample_weight
|
| 139 |
+
sample_weights_train = compute_sample_weight("balanced", y_train)
|
| 140 |
+
|
| 141 |
+
models = {
|
| 142 |
+
"dt": DecisionTreeClassifier(
|
| 143 |
+
max_depth=10,
|
| 144 |
+
class_weight="balanced",
|
| 145 |
+
random_state=42,
|
| 146 |
+
),
|
| 147 |
+
"rf": RandomForestClassifier(
|
| 148 |
+
n_estimators=400,
|
| 149 |
+
max_depth=14,
|
| 150 |
+
class_weight="balanced",
|
| 151 |
+
n_jobs=-1,
|
| 152 |
+
random_state=42,
|
| 153 |
+
),
|
| 154 |
+
"xgb": XGBClassifier(
|
| 155 |
+
n_estimators=500,
|
| 156 |
+
learning_rate=0.03,
|
| 157 |
+
max_depth=8,
|
| 158 |
+
num_class=len(LABEL_NAMES),
|
| 159 |
+
n_jobs=-1,
|
| 160 |
+
random_state=42,
|
| 161 |
+
eval_metric="mlogloss",
|
| 162 |
+
verbosity=0,
|
| 163 |
+
),
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
trained = {}
|
| 167 |
+
|
| 168 |
+
for name, model in models.items():
|
| 169 |
+
logger.info("Training %s ...", name.upper())
|
| 170 |
+
if name == "xgb":
|
| 171 |
+
model.fit(X_train, y_train, sample_weight=sample_weights_train)
|
| 172 |
+
else:
|
| 173 |
+
model.fit(X_train, y_train)
|
| 174 |
+
|
| 175 |
+
# 5-fold CV accuracy
|
| 176 |
+
cv_scores = cross_val_score(model, X_train, y_train, cv=cv, scoring="accuracy", n_jobs=-1)
|
| 177 |
+
logger.info("[%s] CV accuracy: %.4f +/- %.4f", name.upper(), cv_scores.mean(), cv_scores.std())
|
| 178 |
+
print(f"[{name.upper()}] 5-Fold CV Accuracy: {cv_scores.mean():.4f} +/- {cv_scores.std():.4f}")
|
| 179 |
+
|
| 180 |
+
y_pred = model.predict(X_test)
|
| 181 |
+
print(f"\n[{name.upper()}] Classification Report (Test Set):")
|
| 182 |
+
print(classification_report(
|
| 183 |
+
y_test, y_pred,
|
| 184 |
+
labels=list(range(len(LABEL_NAMES))),
|
| 185 |
+
target_names=LABEL_NAMES,
|
| 186 |
+
zero_division=0,
|
| 187 |
+
))
|
| 188 |
+
|
| 189 |
+
model_path = MODELS_DIR / f"selector_{name}.joblib"
|
| 190 |
+
# Tag the estimator with the training-run hash so loaders can verify
|
| 191 |
+
# it matches the on-disk feature_ranges.json / feature_names.json.
|
| 192 |
+
try:
|
| 193 |
+
setattr(model, "_dahs_run_hash", run_hash)
|
| 194 |
+
except Exception:
|
| 195 |
+
pass
|
| 196 |
+
joblib.dump(model, model_path)
|
| 197 |
+
logger.info("Saved model -> %s", model_path)
|
| 198 |
+
|
| 199 |
+
trained[name] = model
|
| 200 |
+
|
| 201 |
+
# ------------------------------------------------------------------
|
| 202 |
+
# NEW in DAHS_2: Export interpretability artifacts
|
| 203 |
+
# ------------------------------------------------------------------
|
| 204 |
+
|
| 205 |
+
# 1. Feature ranges (for OOD detection in BatchwiseSelector)
|
| 206 |
+
feature_ranges = {}
|
| 207 |
+
for i, name in enumerate(feature_cols):
|
| 208 |
+
feature_ranges[name] = [float(X_train[:, i].min()), float(X_train[:, i].max())]
|
| 209 |
+
feature_ranges_payload = {
|
| 210 |
+
"_meta": {
|
| 211 |
+
"run_hash": run_hash,
|
| 212 |
+
"n_train": int(X_train.shape[0]),
|
| 213 |
+
"feature_count": len(feature_cols),
|
| 214 |
+
},
|
| 215 |
+
"ranges": feature_ranges,
|
| 216 |
+
}
|
| 217 |
+
with open(MODELS_DIR / "feature_ranges.json", "w") as f:
|
| 218 |
+
json.dump(feature_ranges_payload, f, indent=2)
|
| 219 |
+
logger.info("Saved feature_ranges.json -> %s", MODELS_DIR / "feature_ranges.json")
|
| 220 |
+
|
| 221 |
+
# 2. Feature names with descriptions
|
| 222 |
+
from src.features import FEATURE_DESCRIPTIONS
|
| 223 |
+
feature_names_data = [
|
| 224 |
+
{
|
| 225 |
+
"name": name,
|
| 226 |
+
"description": FEATURE_DESCRIPTIONS.get(name, name),
|
| 227 |
+
"category": (
|
| 228 |
+
"disruption" if name in ("disruption_intensity", "queue_imbalance", "job_mix_entropy", "time_pressure_ratio")
|
| 229 |
+
else "utilization" if "utilization" in name or "bottleneck" in name
|
| 230 |
+
else "timing" if "due" in name or "tard" in name or "sla" in name
|
| 231 |
+
else "queue" if "queue" in name or "throughput" in name
|
| 232 |
+
else "system"
|
| 233 |
+
),
|
| 234 |
+
"index": i,
|
| 235 |
+
}
|
| 236 |
+
for i, name in enumerate(feature_cols)
|
| 237 |
+
]
|
| 238 |
+
feature_names_payload = {
|
| 239 |
+
"_meta": {"run_hash": run_hash},
|
| 240 |
+
"features": feature_names_data,
|
| 241 |
+
}
|
| 242 |
+
with open(MODELS_DIR / "feature_names.json", "w") as f:
|
| 243 |
+
json.dump(feature_names_payload, f, indent=2)
|
| 244 |
+
logger.info("Saved feature_names.json -> %s", MODELS_DIR / "feature_names.json")
|
| 245 |
+
|
| 246 |
+
# 3. Decision tree structure (for frontend glass-box)
|
| 247 |
+
dt_structure = _extract_dt_structure(trained["dt"], feature_cols)
|
| 248 |
+
dt_structure["_meta"] = {"run_hash": run_hash}
|
| 249 |
+
with open(MODELS_DIR / "dt_structure.json", "w") as f:
|
| 250 |
+
json.dump(dt_structure, f, indent=2)
|
| 251 |
+
logger.info("Saved dt_structure.json -> %s", MODELS_DIR / "dt_structure.json")
|
| 252 |
+
|
| 253 |
+
# ------------------------------------------------------------------
|
| 254 |
+
# Feature importance plot (RF + XGB side-by-side, dark theme)
|
| 255 |
+
# ------------------------------------------------------------------
|
| 256 |
+
rf_importances = trained["rf"].feature_importances_
|
| 257 |
+
xgb_importances = trained["xgb"].feature_importances_
|
| 258 |
+
|
| 259 |
+
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
|
| 260 |
+
fig.patch.set_facecolor("#0f1117")
|
| 261 |
+
|
| 262 |
+
for ax, importances, title, color in zip(
|
| 263 |
+
axes,
|
| 264 |
+
[rf_importances, xgb_importances],
|
| 265 |
+
["Random Forest Feature Importance", "XGBoost Feature Importance"],
|
| 266 |
+
["#4fc3f7", "#a5d6a7"],
|
| 267 |
+
):
|
| 268 |
+
ax.set_facecolor("#1a1d27")
|
| 269 |
+
sorted_idx = np.argsort(importances)[-15:]
|
| 270 |
+
ax.barh(
|
| 271 |
+
[feature_cols[i] for i in sorted_idx],
|
| 272 |
+
importances[sorted_idx],
|
| 273 |
+
color=color,
|
| 274 |
+
alpha=0.85,
|
| 275 |
+
)
|
| 276 |
+
ax.set_title(title, color="white", fontsize=13, pad=10)
|
| 277 |
+
ax.set_xlabel("Importance", color="#aaaaaa")
|
| 278 |
+
ax.tick_params(colors="#cccccc", labelsize=9)
|
| 279 |
+
for spine in ax.spines.values():
|
| 280 |
+
spine.set_color("#333344")
|
| 281 |
+
spine.set_linewidth(0.5)
|
| 282 |
+
|
| 283 |
+
fig.suptitle("Heuristic Selector — Feature Importances (DAHS_2)", color="white", fontsize=15, y=1.01)
|
| 284 |
+
plt.tight_layout()
|
| 285 |
+
fi_path = PLOTS_DIR / "feature_importance.png"
|
| 286 |
+
plt.savefig(fi_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
|
| 287 |
+
plt.close()
|
| 288 |
+
logger.info("Saved feature importance plot -> %s", fi_path)
|
| 289 |
+
|
| 290 |
+
# ------------------------------------------------------------------
|
| 291 |
+
# Decision tree visualization
|
| 292 |
+
# ------------------------------------------------------------------
|
| 293 |
+
fig, ax = plt.subplots(figsize=(24, 10))
|
| 294 |
+
fig.patch.set_facecolor("#0f1117")
|
| 295 |
+
ax.set_facecolor("#0f1117")
|
| 296 |
+
plot_tree(
|
| 297 |
+
trained["dt"],
|
| 298 |
+
feature_names=feature_cols,
|
| 299 |
+
class_names=LABEL_NAMES,
|
| 300 |
+
filled=True,
|
| 301 |
+
max_depth=4,
|
| 302 |
+
fontsize=7,
|
| 303 |
+
ax=ax,
|
| 304 |
+
)
|
| 305 |
+
ax.set_title("Decision Tree Classifier (depth≤4 shown)", color="white", fontsize=14)
|
| 306 |
+
dt_path = PLOTS_DIR / "decision_tree.png"
|
| 307 |
+
plt.savefig(dt_path, dpi=120, bbox_inches="tight", facecolor=fig.get_facecolor())
|
| 308 |
+
plt.close()
|
| 309 |
+
logger.info("Saved decision tree plot -> %s", dt_path)
|
| 310 |
+
|
| 311 |
+
return trained
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
if __name__ == "__main__":
|
| 315 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 316 |
+
train_selector_models()
|