Spaces:
Sleeping
Sleeping
File size: 4,952 Bytes
"""Tests for the four BatchwiseSelector edge-case guardrails.
Guardrails:
- Trivial load (< TRIVIAL_LOAD jobs) → FIFO
- Overload (avg utilization > OVERLOAD_THRESHOLD) → ATC
- Out-of-distribution (>10% beyond range) → ATC
- Starvation (job waiting > STARVATION_LIMIT) → force-promote in dispatch
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List
import numpy as np
from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES
from src.hybrid_scheduler import BatchwiseSelector
# Length of SCENARIO_FEATURE_NAMES; sizes the uniform feature-importance
# vector passed to the selector in _selector_with_ranges().
N_FEATS = len(SCENARIO_FEATURE_NAMES)
class _StubModel:
"""Always votes for WSPT (idx 4) β so any time we see ATC/FIFO selected,
a guardrail must have fired."""
def predict_proba(self, X):
proba = np.zeros((1, 6))
proba[0, 4] = 1.0
return proba
@dataclass
class _Op:
zone_id: int = 0
nominal_proc_time: float = 5.0
@dataclass
class _MiniJob:
job_id: int
job_type: str = "C"
arrival_time: float = 0.0
due_date: float = 100.0
operations: list = field(default_factory=lambda: [_Op()])
current_op_idx: int = 0
status: str = "waiting"
completion_time: float = -1.0
@property
def is_complete(self):
return False
def remaining_proc_time(self):
return 5.0
def _state(n_orders=50, util=0.5, n_broken=0):
waiting = [_MiniJob(job_id=i, due_date=100.0 + 50 * i) for i in range(min(n_orders, 4))]
return {
"current_time": 10.0,
"n_orders_in_system": n_orders,
"queue_sizes": {z: max(1, n_orders // 8) for z in range(8)},
"zone_utilization": {z: util for z in range(8)},
"n_broken_stations": n_broken,
"lunch_active": False,
"surge_multiplier": 1.0,
"completed_so_far": 0,
"waiting_jobs": waiting,
"completed_jobs": [],
"all_jobs": {j.job_id: j for j in waiting},
}
def _wide_ranges():
    """Return a fresh dict mapping every scenario feature to (-1e6, 1e6).

    The bounds are deliberately permissive so that the out-of-distribution
    check is not meant to fire on the baseline state; tests that target OOD
    overwrite individual entries of the returned dict.
    """
    return dict.fromkeys(SCENARIO_FEATURE_NAMES, (-1e6, 1e6))
def _selector_with_ranges(ranges=None):
    """Create a ``BatchwiseSelector`` wired to the stub model.

    Args:
        ranges: Feature ranges to install on the feature extractor. When
            ``None``, the permissive ranges from ``_wide_ranges()`` are used.

    Returns:
        A selector whose model is ``_StubModel`` and whose feature
        importances are uniform across all ``N_FEATS`` features.
    """
    extractor = FeatureExtractor()
    if ranges is None:
        ranges = _wide_ranges()
    extractor.set_feature_ranges(ranges)

    # Equal weight for every feature so no single one dominates.
    uniform_importances = np.ones(N_FEATS) / N_FEATS
    return BatchwiseSelector(
        model=_StubModel(),
        feature_extractor=extractor,
        feature_importances=uniform_importances,
        feature_names=list(SCENARIO_FEATURE_NAMES),
    )
def test_trivial_load_forces_fifo():
    """With only three orders in the system, a guardrail should select FIFO."""
    selector = _selector_with_ranges()
    tiny_load = _state(n_orders=3, util=0.4)
    selector.update_state(tiny_load)
    selector._reevaluate(now=0.0)

    # The stub model always votes WSPT, so "fifo" can only come from an override.
    assert selector._current_heuristic == "fifo"
    latest_entry = selector.switching_log.entries[-1]
    assert latest_entry["reason"].startswith("guardrail")
def test_overload_locks_to_atc():
    """With 80 orders and 0.95 utilization everywhere, a guardrail should select ATC."""
    selector = _selector_with_ranges()
    heavy_load = _state(n_orders=80, util=0.95)
    selector.update_state(heavy_load)
    selector._reevaluate(now=0.0)

    # The stub model always votes WSPT, so "atc" can only come from an override.
    assert selector._current_heuristic == "atc"
    latest_entry = selector.switching_log.entries[-1]
    assert latest_entry["reason"].startswith("guardrail")
def test_ood_falls_back_to_atc():
    """A utilization well outside the configured range should fall back to ATC.

    NOTE(review): the assertions only check for "atc" and a reason that
    starts with "guardrail", so they cannot tell the OOD guardrail apart
    from any other guardrail that also selects ATC. The scenario relies on
    50 orders not counting as trivial load and util=0.85 staying below the
    overload threshold (5 and 0.92 per the original author's note — verify
    against the selector's constants).
    """
    # Allowed average utilization is [0.10, 0.70]; the state below reports
    # 0.85, i.e. more than 10% past the upper bound.
    narrowed = {**_wide_ranges(), "zone_utilization_avg": (0.10, 0.70)}
    selector = _selector_with_ranges(narrowed)
    selector.update_state(_state(n_orders=50, util=0.85))
    selector._reevaluate(now=0.0)

    assert selector._current_heuristic == "atc"
    latest_entry = selector.switching_log.entries[-1]
    assert latest_entry["reason"].startswith("guardrail")
def test_no_guardrail_uses_ml_choice():
    """On a moderate, in-range state the stub model's choice should be used as-is."""
    selector = _selector_with_ranges()
    moderate_load = _state(n_orders=30, util=0.5)
    selector.update_state(moderate_load)
    selector._reevaluate(now=0.0)

    # "wspt" is the class the stub model always votes for.
    assert selector._current_heuristic == "wspt"
    latest_entry = selector.switching_log.entries[-1]
    assert latest_entry["reason"] == "ml_decision"
# ---------------------------------------------------------------------------
# Starvation prevention
# ---------------------------------------------------------------------------
@dataclass
class _FakeJob:
job_id: int
arrival_time: float
job_type: str = "C"
due_date: float = 9999.0
operations: list = field(default_factory=list)
current_op_idx: int = 0
status: str = "waiting"
@property
def is_complete(self):
return False
def remaining_proc_time(self):
return 5.0
def test_starvation_promotes_old_jobs_to_front():
    """The job that has waited longest (90 time units) must come out first.

    NOTE(review): the heuristic is pinned to "fifo". If FIFO orders jobs by
    ``arrival_time``, the expected order [2, 3, 1] would also result without
    any starvation promotion — confirm this test can actually fail when the
    starvation guardrail is disabled.
    """
    selector = _selector_with_ranges()
    selector._current_heuristic = "fifo"  # simple ordering for the assertion
    now = 200.0

    # Jobs that have been waiting 10, 90 and 30 time units respectively.
    young = _FakeJob(job_id=1, arrival_time=now - 10.0)
    old = _FakeJob(job_id=2, arrival_time=now - 90.0)
    middle = _FakeJob(job_id=3, arrival_time=now - 30.0)

    # Submit in a non-sorted order so the result depends on the selector.
    submitted = [young, middle, old]
    ordered = selector.dispatch(submitted, current_time=now, zone_id=0)

    assert ordered[0].job_id == 2
    remaining_ids = [job.job_id for job in ordered[1:]]
    assert remaining_ids == [3, 1]
|