Spaces:
Sleeping
Sleeping
| """Tests for the four BatchwiseSelector edge-case guardrails. | |
| Guardrails: | |
| - Trivial load (< TRIVIAL_LOAD jobs) → FIFO | |
| - Overload (avg utilization > OVERLOAD_THRESHOLD) → ATC | |
| - Out-of-distribution (>10% beyond range) → ATC | |
| - Starvation (job waiting > STARVATION_LIMIT)→ force-promote in dispatch | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import List | |
| import numpy as np | |
| from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES | |
| from src.hybrid_scheduler import BatchwiseSelector | |
| N_FEATS = len(SCENARIO_FEATURE_NAMES) | |
| class _StubModel: | |
| """Always votes for WSPT (idx 4) — so any time we see ATC/FIFO selected, | |
| a guardrail must have fired.""" | |
| def predict_proba(self, X): | |
| proba = np.zeros((1, 6)) | |
| proba[0, 4] = 1.0 | |
| return proba | |
| class _Op: | |
| zone_id: int = 0 | |
| nominal_proc_time: float = 5.0 | |
| class _MiniJob: | |
| job_id: int | |
| job_type: str = "C" | |
| arrival_time: float = 0.0 | |
| due_date: float = 100.0 | |
| operations: list = field(default_factory=lambda: [_Op()]) | |
| current_op_idx: int = 0 | |
| status: str = "waiting" | |
| completion_time: float = -1.0 | |
| def is_complete(self): | |
| return False | |
| def remaining_proc_time(self): | |
| return 5.0 | |
| def _state(n_orders=50, util=0.5, n_broken=0): | |
| waiting = [_MiniJob(job_id=i, due_date=100.0 + 50 * i) for i in range(min(n_orders, 4))] | |
| return { | |
| "current_time": 10.0, | |
| "n_orders_in_system": n_orders, | |
| "queue_sizes": {z: max(1, n_orders // 8) for z in range(8)}, | |
| "zone_utilization": {z: util for z in range(8)}, | |
| "n_broken_stations": n_broken, | |
| "lunch_active": False, | |
| "surge_multiplier": 1.0, | |
| "completed_so_far": 0, | |
| "waiting_jobs": waiting, | |
| "completed_jobs": [], | |
| "all_jobs": {j.job_id: j for j in waiting}, | |
| } | |
| def _wide_ranges(): | |
| """Permissive ranges so OOD does NOT fire on baseline state. Tests that | |
| target OOD override specific entries.""" | |
| return {n: (-1e6, 1e6) for n in SCENARIO_FEATURE_NAMES} | |
| def _selector_with_ranges(ranges=None): | |
| fe = FeatureExtractor() | |
| fe.set_feature_ranges(ranges if ranges is not None else _wide_ranges()) | |
| return BatchwiseSelector( | |
| model=_StubModel(), | |
| feature_extractor=fe, | |
| feature_importances=np.ones(N_FEATS) / N_FEATS, | |
| feature_names=list(SCENARIO_FEATURE_NAMES), | |
| ) | |
| def test_trivial_load_forces_fifo(): | |
| sel = _selector_with_ranges() | |
| sel.update_state(_state(n_orders=3, util=0.4)) | |
| sel._reevaluate(now=0.0) | |
| assert sel._current_heuristic == "fifo" | |
| last = sel.switching_log.entries[-1] | |
| assert last["reason"].startswith("guardrail") | |
| def test_overload_locks_to_atc(): | |
| sel = _selector_with_ranges() | |
| sel.update_state(_state(n_orders=80, util=0.95)) | |
| sel._reevaluate(now=0.0) | |
| assert sel._current_heuristic == "atc" | |
| last = sel.switching_log.entries[-1] | |
| assert last["reason"].startswith("guardrail") | |
| def test_ood_falls_back_to_atc(): | |
| # Ranges where util is in [0.10, 0.70] but state has util=0.85 (>10% over). | |
| ranges = _wide_ranges() | |
| ranges["zone_utilization_avg"] = (0.10, 0.70) | |
| sel = _selector_with_ranges(ranges) | |
| sel.update_state(_state(n_orders=50, util=0.85)) | |
| sel._reevaluate(now=0.0) | |
| # Trivial load false (50 > 5), overload false (<0.92), so OOD must fire. | |
| assert sel._current_heuristic == "atc" | |
| last = sel.switching_log.entries[-1] | |
| assert last["reason"].startswith("guardrail") | |
| def test_no_guardrail_uses_ml_choice(): | |
| sel = _selector_with_ranges() | |
| sel.update_state(_state(n_orders=30, util=0.5)) | |
| sel._reevaluate(now=0.0) | |
| assert sel._current_heuristic == "wspt" # the stub's argmax | |
| last = sel.switching_log.entries[-1] | |
| assert last["reason"] == "ml_decision" | |
| # --------------------------------------------------------------------------- | |
| # Starvation prevention | |
| # --------------------------------------------------------------------------- | |
| class _FakeJob: | |
| job_id: int | |
| arrival_time: float | |
| job_type: str = "C" | |
| due_date: float = 9999.0 | |
| operations: list = field(default_factory=list) | |
| current_op_idx: int = 0 | |
| status: str = "waiting" | |
| def is_complete(self): | |
| return False | |
| def remaining_proc_time(self): | |
| return 5.0 | |
| def test_starvation_promotes_old_jobs_to_front(): | |
| sel = _selector_with_ranges() | |
| sel._current_heuristic = "fifo" # simple ordering for the assertion | |
| now = 200.0 | |
| young = _FakeJob(job_id=1, arrival_time=now - 10.0) | |
| old = _FakeJob(job_id=2, arrival_time=now - 90.0) | |
| middle = _FakeJob(job_id=3, arrival_time=now - 30.0) | |
| ordered = sel.dispatch([young, middle, old], current_time=now, zone_id=0) | |
| assert ordered[0].job_id == 2 | |
| rest_ids = [j.job_id for j in ordered[1:]] | |
| assert rest_ids == [3, 1] | |