File size: 4,952 Bytes
906e104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""Tests for the four BatchwiseSelector edge-case guardrails.

Guardrails:
  - Trivial load (< TRIVIAL_LOAD jobs)         β†’ FIFO
  - Overload (avg utilization > OVERLOAD_THRESHOLD) β†’ ATC
  - Out-of-distribution (>10% beyond range)    β†’ ATC
  - Starvation (job waiting > STARVATION_LIMIT)β†’ force-promote in dispatch
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List

import numpy as np

from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES
from src.hybrid_scheduler import BatchwiseSelector

N_FEATS = len(SCENARIO_FEATURE_NAMES)


class _StubModel:
    """Fake classifier whose vote is always WSPT (class index 4).

    Because the stub never favours ATC or FIFO, any test that observes one
    of those heuristics being selected proves that a guardrail overrode the
    model's choice.
    """

    def predict_proba(self, X):
        # X is ignored: emit a single one-hot row over six classes with all
        # of the probability mass on index 4 (WSPT).
        return np.eye(6)[[4]]


@dataclass
class _Op:
    """Single-operation stand-in: the zone it runs in and its nominal time."""

    zone_id: int = 0  # zone index the operation belongs to
    nominal_proc_time: float = 5.0  # nominal processing duration


@dataclass
class _MiniJob:
    """Minimal job stand-in used to populate the synthetic state dict.

    Each instance always reports itself as not complete and returns a fixed
    5.0 from ``remaining_proc_time``. By default it carries one ``_Op``.
    """

    job_id: int
    job_type: str = "C"
    arrival_time: float = 0.0
    due_date: float = 100.0
    # Fresh single-operation list per instance (no shared mutable default).
    operations: list = field(default_factory=lambda: [_Op()])
    current_op_idx: int = 0
    status: str = "waiting"
    completion_time: float = -1.0  # sentinel: no completion recorded

    @property
    def is_complete(self):
        # Jobs built for these tests never finish.
        return False

    def remaining_proc_time(self):
        # Fixed value; equals the default _Op's nominal_proc_time.
        return 5.0


def _state(n_orders=50, util=0.5, n_broken=0):
    """Build a synthetic simulator-state dict for ``update_state``.

    Args:
        n_orders: Reported ``n_orders_in_system``. It also sets the per-zone
            queue size (``max(1, n_orders // 8)``) and caps the number of
            waiting jobs at 4.
        util: Utilization assigned to each of the 8 zones.
        n_broken: Reported number of broken stations.

    Returns:
        A state dict. Its waiting jobs are ``_MiniJob`` instances with ids
        ``0..k-1`` and due dates starting at 100.0, spaced 50 apart. The same
        jobs populate ``all_jobs``, keyed by id.
    """
    zones = range(8)
    waiting = []
    for jid in range(min(n_orders, 4)):
        waiting.append(_MiniJob(job_id=jid, due_date=100.0 + 50 * jid))
    per_zone_queue = max(1, n_orders // 8)
    return {
        "current_time": 10.0,
        "n_orders_in_system": n_orders,
        "queue_sizes": dict.fromkeys(zones, per_zone_queue),
        "zone_utilization": dict.fromkeys(zones, util),
        "n_broken_stations": n_broken,
        "lunch_active": False,
        "surge_multiplier": 1.0,
        "completed_so_far": 0,
        "waiting_jobs": waiting,
        "completed_jobs": [],
        "all_jobs": {job.job_id: job for job in waiting},
    }


def _wide_ranges():
    """Return a ranges dict in which every scenario feature spans [-1e6, 1e6].

    The bounds are loose enough that the OOD guardrail stays silent on the
    baseline state. OOD-specific tests overwrite individual entries.
    """
    bounds = (-1e6, 1e6)
    return dict.fromkeys(SCENARIO_FEATURE_NAMES, bounds)


def _selector_with_ranges(ranges=None):
    """Build a BatchwiseSelector backed by ``_StubModel``.

    Args:
        ranges: Feature-range mapping handed to the extractor. ``None``
            means "use the permissive ``_wide_ranges()`` bounds".

    Returns:
        A selector whose feature importances are uniform across all
        scenario features.
    """
    extractor = FeatureExtractor()
    if ranges is None:
        ranges = _wide_ranges()
    extractor.set_feature_ranges(ranges)
    uniform_importances = np.ones(N_FEATS) / N_FEATS
    return BatchwiseSelector(
        model=_StubModel(),
        feature_extractor=extractor,
        feature_importances=uniform_importances,
        feature_names=list(SCENARIO_FEATURE_NAMES),
    )


def test_trivial_load_forces_fifo():
    """With only 3 orders in the system, the trivial-load guardrail picks FIFO."""
    selector = _selector_with_ranges()
    selector.update_state(_state(n_orders=3, util=0.4))
    selector._reevaluate(now=0.0)

    assert selector._current_heuristic == "fifo"
    assert selector.switching_log.entries[-1]["reason"].startswith("guardrail")


def test_overload_locks_to_atc():
    """At 0.95 utilization in every zone, the overload guardrail picks ATC."""
    selector = _selector_with_ranges()
    selector.update_state(_state(n_orders=80, util=0.95))
    selector._reevaluate(now=0.0)

    assert selector._current_heuristic == "atc"
    assert selector.switching_log.entries[-1]["reason"].startswith("guardrail")


def test_ood_falls_back_to_atc():
    """An out-of-range utilization feature triggers the OOD fallback to ATC."""
    # Training range for avg utilization is [0.10, 0.70]; the state reports
    # 0.85, which lies more than 10% beyond the upper bound.
    narrow = _wide_ranges()
    narrow["zone_utilization_avg"] = (0.10, 0.70)
    selector = _selector_with_ranges(narrow)
    selector.update_state(_state(n_orders=50, util=0.85))
    selector._reevaluate(now=0.0)

    # Trivial load is off (50 > 5) and overload is off (< 0.92), so only the
    # OOD guardrail can explain an ATC selection here.
    assert selector._current_heuristic == "atc"
    assert selector.switching_log.entries[-1]["reason"].startswith("guardrail")


def test_no_guardrail_uses_ml_choice():
    """When no guardrail fires, the selector adopts the model's argmax (WSPT)."""
    selector = _selector_with_ranges()
    selector.update_state(_state(n_orders=30, util=0.5))
    selector._reevaluate(now=0.0)

    assert selector._current_heuristic == "wspt"  # _StubModel always votes WSPT
    assert selector.switching_log.entries[-1]["reason"] == "ml_decision"


# ---------------------------------------------------------------------------
# Starvation prevention
# ---------------------------------------------------------------------------

@dataclass
class _FakeJob:
    """Bare-bones waiting job used by the starvation-dispatch test.

    Only ``job_id`` and ``arrival_time`` are required. The due date defaults
    to 9999.0 and the operation list is empty. Each instance always reports
    itself as not complete and returns a fixed 5.0 from
    ``remaining_proc_time``.
    """

    job_id: int
    arrival_time: float
    job_type: str = "C"
    due_date: float = 9999.0
    operations: list = field(default_factory=list)  # fresh list per instance
    current_op_idx: int = 0
    status: str = "waiting"

    @property
    def is_complete(self):
        # Fake jobs never finish.
        return False

    def remaining_proc_time(self):
        # Constant remaining work, independent of ``operations``.
        return 5.0


def test_starvation_promotes_old_jobs_to_front():
    """The longest-waiting job is moved to the front of the dispatch order.

    The remaining jobs keep the active heuristic's FIFO order.
    """
    selector = _selector_with_ranges()
    selector._current_heuristic = "fifo"  # simple ordering for the assertion
    now = 200.0
    # Waiting times: 10, 90 and 30. Job 2 (90) is presumably past
    # STARVATION_LIMIT while the others are not.
    young = _FakeJob(job_id=1, arrival_time=now - 10.0)
    old = _FakeJob(job_id=2, arrival_time=now - 90.0)
    middle = _FakeJob(job_id=3, arrival_time=now - 30.0)

    queue = [young, middle, old]
    dispatched = selector.dispatch(queue, current_time=now, zone_id=0)

    assert dispatched[0].job_id == 2
    assert [job.job_id for job in dispatched[1:]] == [3, 1]