File size: 4,074 Bytes
906e104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Hysteresis tests for BatchwiseSelector.

DAHS_2.1 uses a *relative* margin: a switch is blocked unless the new
heuristic's confidence exceeds the current's by at least HYSTERESIS_MARGIN
(15% relative, calibration-invariant across RF and XGB predict_proba).

We feed the selector synthetic state with a stub model whose probabilities
we fully control, then assert the switch fires iff
    new_conf >= current_conf * (1 + HYSTERESIS_MARGIN).
"""
from __future__ import annotations

import numpy as np
import pytest

from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES
from src.hybrid_scheduler import BatchwiseSelector


# Number of entries in SCENARIO_FEATURE_NAMES; used below to build the uniform
# feature_importances vector handed to BatchwiseSelector.
N_FEATS = len(SCENARIO_FEATURE_NAMES)


class _StubModel:
    """Stand-in classifier whose output is dictated by the test.

    Assign a probability vector to ``next_proba``; each ``predict_proba``
    call hands that vector back reshaped into a single row.
    """

    def __init__(self):
        # Default: six classes with all probability mass on index 0.
        self.next_proba = np.array([1.0] + [0.0] * 5)

    def predict_proba(self, X):
        """Ignore ``X`` and return ``next_proba`` as a one-row 2-D array."""
        canned = self.next_proba
        return np.reshape(canned, (1, -1))


def _scenario_state(n_orders=50, util=0.5, n_broken=0, lunch=False):
    """Build a synthetic state dict to feed the selector.

    Only the order count, per-zone utilisation, broken-station count and
    lunch flag are configurable; every other field is a fixed placeholder.
    Zones are keyed 0..7, each with a queue size of 5.
    """
    zone_ids = range(8)
    state = dict(
        current_time=10.0,
        n_orders_in_system=n_orders,
        queue_sizes=dict.fromkeys(zone_ids, 5),
        zone_utilization=dict.fromkeys(zone_ids, util),
        n_broken_stations=n_broken,
        lunch_active=lunch,
        surge_multiplier=1.0,
        completed_so_far=0,
    )
    # Fresh empty containers on every call so tests never share mutable state.
    state["waiting_jobs"] = []
    state["completed_jobs"] = []
    state["all_jobs"] = {}
    return state


def _make_selector():
    """Create a BatchwiseSelector backed by a ``_StubModel``.

    Returns:
        A ``(selector, model)`` pair, where ``model`` is whatever the selector
        exposes as ``_model``. Tests overwrite its ``next_proba`` before each
        re-evaluation.
    """
    extractor = FeatureExtractor()
    selector_kwargs = {
        "model": _StubModel(),
        "feature_extractor": extractor,
        # Every feature gets equal weight; the weights sum to 1.
        "feature_importances": np.ones(N_FEATS) / N_FEATS,
        "feature_names": list(SCENARIO_FEATURE_NAMES),
    }
    selector = BatchwiseSelector(**selector_kwargs)
    return selector, selector._model


def test_initial_evaluation_picks_argmax():
    """The first re-evaluation adopts the top-probability heuristic."""
    selector, stub = _make_selector()

    # Flat 0.05 everywhere except index 4, which carries 0.75.
    probabilities = [0.05] * 6
    probabilities[4] = 0.75
    stub.next_proba = np.array(probabilities)

    selector.update_state(_scenario_state())
    selector._reevaluate(now=0.0)

    chosen = selector._current_heuristic
    confidence = selector._current_confidence
    assert chosen == "wspt"
    assert confidence == pytest.approx(0.75, abs=1e-6)


def test_small_relative_advantage_blocked_by_hysteresis():
    """A challenger below the relative margin must not displace the incumbent.

    WSPT is locked in at 0.55, then the stub proposes Slack at 0.60. Since
    0.60 is below 0.55 * 1.15 = 0.6325, the selector must hold WSPT and log
    the attempt as ``hysteresis_blocked``.
    """
    sel, model = _make_selector()
    # Lock in WSPT at 0.55
    model.next_proba = np.array([0.05, 0.05, 0.05, 0.10, 0.55, 0.20])
    sel.update_state(_scenario_state())
    sel._reevaluate(now=0.0)
    assert sel._current_heuristic == "wspt"
    # Pin the baseline explicitly: the threshold arithmetic below is only
    # meaningful if the locked-in confidence really is 0.55.
    assert sel._current_confidence == pytest.approx(0.55, abs=1e-6)

    # New top: Slack at 0.60. Threshold = 0.55 * 1.15 = 0.6325 → blocked.
    model.next_proba = np.array([0.05, 0.05, 0.05, 0.10, 0.15, 0.60])
    sel.update_state(_scenario_state())
    sel._reevaluate(now=20.0)
    assert sel._current_heuristic == "wspt", "Hysteresis should have blocked the switch"
    last = sel.switching_log.entries[-1]
    assert last["reason"] == "hysteresis_blocked"
    assert last["switched"] is False


def test_large_relative_advantage_triggers_switch():
    """A challenger well above the relative margin must trigger a switch.

    WSPT is locked in at 0.45, then the stub proposes ATC at 0.85. Since
    0.85 exceeds 0.45 * 1.15 = 0.5175, the selector must switch and log the
    decision as ``ml_decision``.
    """
    sel, model = _make_selector()
    # Lock in WSPT at 0.45
    model.next_proba = np.array([0.05, 0.05, 0.10, 0.10, 0.45, 0.25])
    sel.update_state(_scenario_state())
    sel._reevaluate(now=0.0)
    assert sel._current_heuristic == "wspt"
    # Pin the baseline explicitly: the threshold arithmetic below is only
    # meaningful if the locked-in confidence really is 0.45.
    assert sel._current_confidence == pytest.approx(0.45, abs=1e-6)

    # ATC at 0.85 — 0.85 > 0.45 * 1.15 = 0.5175, so switch should fire.
    new_top = 0.85
    # Spread the remaining mass evenly over the other five classes.
    rest = (1.0 - new_top) / 5
    proba = np.full(6, rest)
    proba[3] = new_top
    model.next_proba = proba
    sel.update_state(_scenario_state())
    sel._reevaluate(now=20.0)
    assert sel._current_heuristic == "atc"
    last = sel.switching_log.entries[-1]
    assert last["switched"] is True
    assert last["reason"] == "ml_decision"


def test_same_heuristic_keeps_no_hysteresis_block():
    """Re-selecting the incumbent heuristic is a 'hold', not a hysteresis block."""
    selector, stub = _make_selector()
    stub.next_proba = np.array([0.05, 0.05, 0.05, 0.10, 0.55, 0.20])
    selector.update_state(_scenario_state())

    # Evaluate twice against the same probabilities and the same state.
    for timestamp in (0.0, 20.0):
        selector._reevaluate(now=timestamp)

    latest_entry = selector.switching_log.entries[-1]
    assert latest_entry["selected"] == "wspt"
    assert latest_entry["reason"] != "hysteresis_blocked"