# Disruption-System / tests/test_guardrails.py
# Uploaded by Vittal-M ("Upload 66 files"), commit 906e104 (verified).
"""Tests for the four BatchwiseSelector edge-case guardrails.

Guardrails:
- Trivial load (< TRIVIAL_LOAD jobs) → FIFO
- Overload (avg utilization > OVERLOAD_THRESHOLD) → ATC
- Out-of-distribution (>10% beyond range) → ATC
- Starvation (job waiting > STARVATION_LIMIT) → force-promote in dispatch
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List
import numpy as np
from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES
from src.hybrid_scheduler import BatchwiseSelector
# Width of the scenario feature vector (one slot per name in SCENARIO_FEATURE_NAMES).
N_FEATS: int = len(SCENARIO_FEATURE_NAMES)
class _StubModel:
    """Fake classifier whose vote is always WSPT (class index 4).

    Because its argmax never changes, any test that ends up on ATC or FIFO
    proves that a guardrail overrode the model.
    """

    def predict_proba(self, X):
        """Return a fresh (1, 6) one-hot probability row for class 4; ``X`` is ignored."""
        # Row 4 of the 6x6 identity, kept 2-D by indexing with a list.
        return np.eye(6)[[4]]
@dataclass
class _Op:
    """Minimal operation stub: the zone it runs in and its nominal duration."""

    zone_id: int = field(default=0)
    nominal_proc_time: float = field(default=5.0)
@dataclass
class _MiniJob:
    """Lightweight job stub used to populate the ``waiting_jobs`` of ``_state``.

    Only ``job_id`` is required; every other field has a fixed default.
    """

    job_id: int
    job_type: str = "C"
    arrival_time: float = 0.0
    due_date: float = 100.0
    # A fresh single-operation list per instance (never shared between jobs).
    operations: list = field(default_factory=lambda: [_Op()])
    current_op_idx: int = 0
    status: str = "waiting"
    # NOTE(review): -1.0 looks like a "not completed yet" sentinel — confirm.
    completion_time: float = -1.0

    @property
    def is_complete(self):
        """Stub jobs are never reported as complete."""
        return False

    def remaining_proc_time(self):
        """Return a fixed 5.0 units of remaining work."""
        return 5.0
def _state(n_orders=50, util=0.5, n_broken=0):
    """Build a minimal simulator-state dict to feed ``BatchwiseSelector.update_state``.

    Args:
        n_orders: value reported as ``n_orders_in_system``. It also sets every
            zone's queue size to ``n_orders // 8`` (floored at 1) and the
            number of stub waiting jobs (capped at 4).
        util: utilization assigned to each of the 8 zones (ids 0-7).
        n_broken: value reported as ``n_broken_stations``.

    Returns:
        A dict of scalar metrics, per-zone maps, and the stub job lists; every
        waiting job is also indexed by ``job_id`` under ``all_jobs``.
    """
    zone_ids = range(8)
    queue_len = max(1, n_orders // 8)

    waiting = []
    for jid in range(min(n_orders, 4)):
        # Space due dates 50 time units apart so each job has a distinct deadline.
        waiting.append(_MiniJob(job_id=jid, due_date=100.0 + 50 * jid))

    return {
        "current_time": 10.0,
        "n_orders_in_system": n_orders,
        "queue_sizes": dict.fromkeys(zone_ids, queue_len),
        "zone_utilization": dict.fromkeys(zone_ids, util),
        "n_broken_stations": n_broken,
        "lunch_active": False,
        "surge_multiplier": 1.0,
        "completed_so_far": 0,
        "waiting_jobs": waiting,
        "completed_jobs": [],
        "all_jobs": {job.job_id: job for job in waiting},
    }
def _wide_ranges():
    """Return a permissive (-1e6, 1e6) range for every scenario feature.

    Baseline states never trip the OOD guardrail against these bounds; tests
    that target OOD overwrite individual entries in the returned dict.
    """
    return dict.fromkeys(SCENARIO_FEATURE_NAMES, (-1e6, 1e6))
def _selector_with_ranges(ranges=None):
    """Build a BatchwiseSelector wired to the always-WSPT stub model.

    Args:
        ranges: feature-name -> (lo, hi) map passed to the feature extractor.
            When None, ``_wide_ranges()`` is used so the OOD guardrail stays
            quiet.

    Returns:
        A selector with uniform feature importances over all scenario features.
    """
    extractor = FeatureExtractor()
    if ranges is None:
        ranges = _wide_ranges()
    extractor.set_feature_ranges(ranges)

    uniform_importances = np.ones(N_FEATS) / N_FEATS
    return BatchwiseSelector(
        model=_StubModel(),
        feature_extractor=extractor,
        feature_importances=uniform_importances,
        feature_names=list(SCENARIO_FEATURE_NAMES),
    )
def test_trivial_load_forces_fifo():
    """Three orders is below the trivial-load cutoff, so FIFO is forced."""
    selector = _selector_with_ranges()
    selector.update_state(_state(n_orders=3, util=0.4))
    selector._reevaluate(now=0.0)
    # The stub always prefers WSPT, so FIFO here can only come from a guardrail.
    assert selector._current_heuristic == "fifo"
    latest = selector.switching_log.entries[-1]
    assert latest["reason"].startswith("guardrail")
def test_overload_locks_to_atc():
    """Zones at 0.95 utilization (above the 0.92 overload threshold) force ATC."""
    selector = _selector_with_ranges()
    selector.update_state(_state(n_orders=80, util=0.95))
    selector._reevaluate(now=0.0)
    # The stub always prefers WSPT, so ATC here can only come from a guardrail.
    assert selector._current_heuristic == "atc"
    latest = selector.switching_log.entries[-1]
    assert latest["reason"].startswith("guardrail")
def test_ood_falls_back_to_atc():
    """A feature pushed well past its configured range triggers the OOD fallback to ATC."""
    feature = "zone_utilization_avg"
    ranges = _wide_ranges()
    # Guard against a silent no-op. If this feature name drifts, assigning it
    # would just add an unused key, and the failure below would misleadingly
    # look like a guardrail bug ("wspt" != "atc") instead of a stale test.
    assert feature in ranges, f"unknown scenario feature {feature!r}"
    # Configured range is [0.10, 0.70]; the state reports util=0.85 (>10% over).
    ranges[feature] = (0.10, 0.70)
    sel = _selector_with_ranges(ranges)
    sel.update_state(_state(n_orders=50, util=0.85))
    sel._reevaluate(now=0.0)
    # Trivial load false (50 > 5), overload false (<0.92), so OOD must fire.
    assert sel._current_heuristic == "atc"
    last = sel.switching_log.entries[-1]
    assert last["reason"].startswith("guardrail")
def test_no_guardrail_uses_ml_choice():
    """With moderate load and in-range features, the model's pick (WSPT) stands."""
    selector = _selector_with_ranges()
    selector.update_state(_state(n_orders=30, util=0.5))
    selector._reevaluate(now=0.0)
    assert selector._current_heuristic == "wspt"  # argmax of _StubModel
    latest = selector.switching_log.entries[-1]
    assert latest["reason"] == "ml_decision"
# ---------------------------------------------------------------------------
# Starvation prevention
# ---------------------------------------------------------------------------
@dataclass
class _FakeJob:
    """Job stub for the starvation test.

    Only ``job_id`` and ``arrival_time`` are required. ``operations`` defaults
    to a fresh empty list per instance.
    """

    job_id: int
    arrival_time: float
    job_type: str = "C"
    due_date: float = 9999.0
    operations: list = field(default_factory=list)
    current_op_idx: int = 0
    status: str = "waiting"

    @property
    def is_complete(self):
        """Stub jobs are never reported as complete."""
        return False

    def remaining_proc_time(self):
        """Return a fixed 5.0 units of remaining work."""
        return 5.0
def test_starvation_promotes_old_jobs_to_front():
    """The longest-waiting job is dispatched first; the rest keep FIFO order."""
    selector = _selector_with_ranges()
    # Pin FIFO so the non-promoted jobs have a predictable (arrival) order.
    selector._current_heuristic = "fifo"
    now = 200.0
    young = _FakeJob(job_id=1, arrival_time=now - 10.0)   # waited 10
    old = _FakeJob(job_id=2, arrival_time=now - 90.0)     # waited 90 (oldest)
    middle = _FakeJob(job_id=3, arrival_time=now - 30.0)  # waited 30
    dispatched = selector.dispatch([young, middle, old], current_time=now, zone_id=0)
    assert [job.job_id for job in dispatched] == [2, 3, 1]