File size: 36,127 Bytes
906e104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
"""

hybrid_scheduler.py — Batch-wise ML Hybrid Scheduler with Guardrails (DAHS_2)



NEW architecture vs DAHS_1:

  - BatchwiseSelector: re-evaluates every 15 min OR on disruption events

  - Hysteresis: only switches if >15% more confident

  - Edge case guardrails: trivial load, overload, OOD detection

  - Starvation prevention: force-promote jobs waiting >60 min

  - 3-level interpretability log per evaluation

  - Plain English explanations



Also includes (ported from DAHS_1):

  - SwitchingLog class

  - HybridPriority class

  - Factory functions

"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import joblib
import numpy as np

logger = logging.getLogger(__name__)

MODELS_DIR = Path(__file__).parent.parent / "models"


# ---------------------------------------------------------------------------
# Switching Log (enhanced for DAHS_2 with evaluation payload)
# ---------------------------------------------------------------------------

class SwitchingLog:
    """Record every batch-wise heuristic-selection evaluation.

    Each entry stores the full evaluation context: simulation time, feature
    vector, per-heuristic probabilities, the selected heuristic, whether the
    selection changed, a machine-readable reason, the confidence, the top
    features, and a plain-English explanation.

    All numeric/boolean fields written by :meth:`record` are coerced to
    built-in Python types so that ``to_list()`` can be passed straight to
    ``json.dumps`` even when callers hand in NumPy scalars.
    """

    HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]

    def __init__(self) -> None:
        self.entries: List[Dict[str, Any]] = []
        self._last_heuristic: Optional[str] = None
        self._switch_count: int = 0
        self._hysteresis_blocked: int = 0
        self._guardrail_activations: int = 0

    def record(
        self,
        time: float,
        features: List[float],
        probabilities: Dict[str, float],
        selected: str,
        switched: bool,
        reason: str,
        confidence: float,
        top_features: List[Dict[str, Any]],
        plain_english: str,
    ) -> None:
        """Record one batch evaluation.

        Args:
            time: Simulation time of the evaluation (stored rounded to 2 dp).
            features: Feature vector used for the decision (rounded to 4 dp).
            probabilities: Heuristic name -> probability (rounded to 4 dp).
            selected: Name of the heuristic in force after this evaluation.
            switched: Whether the selection changed in this evaluation.
            reason: Machine-readable reason; ``"hysteresis_blocked"`` and any
                value starting with ``"guardrail"`` update dedicated counters.
            confidence: Confidence attached to the decision (rounded to 4 dp).
            top_features: Stored as given.
            plain_english: Human-readable explanation, stored as given.
        """
        if switched:
            self._switch_count += 1
        if reason == "hysteresis_blocked":
            self._hysteresis_blocked += 1
        if reason.startswith("guardrail"):
            self._guardrail_activations += 1
        self._last_heuristic = selected

        self.entries.append({
            # float()/bool() strip NumPy scalar types (np.float32, np.bool_),
            # which the json module cannot serialize.
            "time": round(float(time), 2),
            "features": [round(float(f), 4) for f in features],
            "probabilities": {k: round(float(v), 4) for k, v in probabilities.items()},
            "selected": selected,
            "switched": bool(switched),
            "reason": reason,
            "confidence": round(float(confidence), 4),
            "topFeatures": top_features,
            "plainEnglish": plain_english,
        })

    @property
    def total_evaluations(self) -> int:
        """Number of evaluations recorded so far."""
        return len(self.entries)

    @property
    def switch_count(self) -> int:
        """Number of recorded evaluations flagged as a switch."""
        return self._switch_count

    def heuristic_distribution(self) -> Dict[str, float]:
        """Return the fraction of evaluations assigned to each heuristic.

        Keys are sorted by heuristic name; an empty log yields ``{}``.
        """
        if not self.entries:
            return {}
        counts: Dict[str, int] = {}
        for entry in self.entries:
            name = entry["selected"]
            counts[name] = counts.get(name, 0) + 1
        total = len(self.entries)
        return {name: count / total for name, count in sorted(counts.items())}

    def switching_rate(self) -> float:
        """Return switches divided by the number of transitions (entries - 1).

        Returns 0.0 when fewer than two evaluations have been recorded.
        """
        if len(self.entries) < 2:
            return 0.0
        return self._switch_count / (len(self.entries) - 1)

    def summary(self) -> Dict[str, Any]:
        """Return a summary dict of counters, distribution and dominant heuristic."""
        dist = self.heuristic_distribution()
        return {
            "totalEvaluations": self.total_evaluations,
            "switchCount": self._switch_count,
            "switchingRate": round(self.switching_rate(), 4),
            "hysteresisBlocked": self._hysteresis_blocked,
            "guardrailActivations": self._guardrail_activations,
            "distribution": {k: round(v, 4) for k, v in dist.items()},
            "dominantHeuristic": max(dist, key=dist.get) if dist else "none",
        }

    def to_list(self) -> List[Dict[str, Any]]:
        """Return the entries list (the live internal list, not a copy)."""
        return self.entries


# ---------------------------------------------------------------------------
# BatchwiseSelector — Core DAHS_2 scheduler
# ---------------------------------------------------------------------------

class BatchwiseSelector:
    """Batch-wise ML heuristic selector with guardrails and hysteresis.

    The selection is re-evaluated when ``EVAL_INTERVAL`` minutes have elapsed
    since the last evaluation, or when the stored simulation state shows a
    changed ``n_broken_stations`` count or a changed ``lunch_active`` flag.
    A model-suggested switch is accepted only if the suggested heuristic's
    probability is at least ``(1 + HYSTERESIS_MARGIN)`` times the probability
    the model currently assigns to the incumbent heuristic.

    Edge-case guardrails (checked before the model is consulted):

    - Trivial load: ``n_orders_in_system < TRIVIAL_LOAD`` -> FIFO
    - Overload: ``zone_utilization_avg > OVERLOAD_THRESHOLD`` -> ATC
    - OOD: feature extractor reports out-of-distribution features
      (tolerance 0.10) -> ATC
    - Starvation: jobs waiting longer than ``STARVATION_LIMIT`` minutes are
      moved to the front of every dispatched queue.
    """

    EVAL_INTERVAL      = 15.0   # minutes between re-evaluations
    # Relative margin: new heuristic's probability must exceed the incumbent's
    # current probability x (1 + margin).
    HYSTERESIS_MARGIN  = 0.15
    TRIVIAL_LOAD       = 5       # skip ML if fewer jobs
    OVERLOAD_THRESHOLD = 0.92    # lock to ATC
    STARVATION_LIMIT   = 60.0    # force-promote starving jobs (minutes)

    HEURISTIC_MAP = {
        0: "fifo", 1: "priority_edd", 2: "critical_ratio",
        3: "atc",  4: "wspt",         5: "slack",
    }
    HEURISTIC_LABELS = {
        "fifo": "FIFO", "priority_edd": "Priority-EDD",
        "critical_ratio": "Critical-Ratio", "atc": "ATC",
        "wspt": "WSPT", "slack": "Slack",
    }

    # Plain-English reason templates keyed by (heuristic, feature name)
    _EXPLANATION_MAP = {
        ("atc",            "time_pressure_ratio"):  "many jobs are nearing their deadlines",
        ("atc",            "surge_multiplier"):      "demand surging above normal rate",
        ("atc",            "zone_utilization_avg"):  "warehouse is highly loaded",
        ("critical_ratio", "n_broken_stations"):     "station breakdowns are causing bottlenecks",
        ("critical_ratio", "disruption_intensity"):  "high disruption intensity detected",
        ("fifo",           "zone_utilization_avg"):  "load is light, simple ordering is optimal",
        ("fifo",           "n_orders_in_system"):    "few jobs in system, FIFO is stable",
        ("wspt",           "avg_priority_weight"):   "high-value short jobs should be prioritized",
        ("wspt",           "avg_remaining_proc_time"): "many short jobs in queue",
        ("priority_edd",   "n_express_orders_pct"):  "high fraction of express orders",
        ("priority_edd",   "fraction_already_late"): "many jobs past due date",
        ("slack",          "avg_due_date_tightness"): "deadlines are extremely tight",
        ("slack",          "sla_breach_rate_current"): "SLA breach rate is rising",
    }

    def __init__(
        self,
        model: Any,
        feature_extractor: Any,
        feature_importances: Optional[np.ndarray] = None,
        feature_names: Optional[List[str]] = None,
    ) -> None:
        """Store the classifier, feature extractor and optional importance data.

        Args:
            model: Object exposing ``predict_proba(X)``.
            feature_extractor: Object exposing ``extract_scenario_features``;
                ``_feature_ranges`` and ``is_out_of_distribution`` are used
                when present.
            feature_importances: Per-feature importances, used for the
                interpretability log and explanations.
            feature_names: Names aligned with the feature vector; defaults to
                the project's scenario feature names when omitted.
        """
        self._model = model
        self._fe = feature_extractor
        self._feature_importances = feature_importances
        self._feature_names = feature_names or []

        self._current_heuristic: str = "fifo"
        self._current_confidence: float = 0.0
        self._current_from_guardrail: bool = False
        self._last_eval_time: float = -999.0
        self._last_breakdown_count: int = 0
        self._last_lunch_state: bool = False

        self.switching_log = SwitchingLog()
        self._sim_state: Optional[Dict[str, Any]] = None

    def update_state(self, sim_state: Dict[str, Any]) -> None:
        """Update stored simulation state (called before dispatch)."""
        self._sim_state = sim_state

    # ------------------------------------------------------------------
    # Main dispatch interface
    # ------------------------------------------------------------------

    def dispatch(
        self,
        jobs: List[Any],
        current_time: float,
        zone_id: int,
    ) -> List[Any]:
        """Order ``jobs`` with the current heuristic, re-evaluating it first if due.

        Args:
            jobs: Jobs to order; an empty list is returned unchanged.
            current_time: Current simulation time in minutes.
            zone_id: Zone identifier, forwarded to the heuristic function.

        Returns:
            The jobs ordered by the active heuristic, with starving jobs
            promoted to the front.
        """
        if not jobs:
            return jobs

        # Imported lazily so the heuristics module is only needed when there
        # is actually something to dispatch.
        from src.heuristics import (
            fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
            atc_dispatch, wspt_dispatch, slack_dispatch,
        )

        dispatch_fns: Dict[str, Callable] = {
            "fifo": fifo_dispatch,
            "priority_edd": priority_edd_dispatch,
            "critical_ratio": critical_ratio_dispatch,
            "atc": atc_dispatch,
            "wspt": wspt_dispatch,
            "slack": slack_dispatch,
        }

        # Re-evaluate if needed (time-based or event-triggered)
        if self._sim_state is not None and self._should_reevaluate(current_time):
            self._reevaluate(current_time)

        fn = dispatch_fns.get(self._current_heuristic, fifo_dispatch)
        ordered = fn(jobs, current_time, zone_id)

        # Starvation prevention is applied on top of whatever order the
        # heuristic produced.
        return self._apply_starvation_prevention(ordered, current_time)

    def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Callable interface (same as dispatch)."""
        return self.dispatch(jobs, current_time, zone_id)

    # ------------------------------------------------------------------
    # Re-evaluation logic
    # ------------------------------------------------------------------

    def _should_reevaluate(self, now: float) -> bool:
        """Return True if the heuristic selection should be re-evaluated."""
        if self._sim_state is None:
            return False

        # Time-based trigger
        if now - self._last_eval_time >= self.EVAL_INTERVAL:
            return True

        # Event: breakdown count changed since the last evaluation
        if self._sim_state.get("n_broken_stations", 0) != self._last_breakdown_count:
            return True

        # Event: lunch state changed since the last evaluation
        if self._sim_state.get("lunch_active", False) != self._last_lunch_state:
            return True

        return False

    def _reevaluate(self, now: float) -> None:
        """Run guardrails and the model, then decide whether to switch heuristic.

        Every completed evaluation appends one entry to ``switching_log``.
        Feature-extraction or prediction failures are logged and leave the
        current selection untouched.
        """
        if self._sim_state is None:
            return

        # Trigger bookkeeping is updated up front, so a failed evaluation is
        # not retried until the next trigger fires.
        self._last_eval_time = now
        self._last_breakdown_count = self._sim_state.get("n_broken_stations", 0)
        self._last_lunch_state = self._sim_state.get("lunch_active", False)

        try:
            features = self._fe.extract_scenario_features(self._sim_state)
        except Exception as e:
            logger.warning("Feature extraction failed: %s", e)
            return

        # Guardrails take precedence over the model.
        guardrail = self._check_guardrails(features)
        if guardrail is not None:
            switched = guardrail != self._current_heuristic
            probas = {h: (1.0 if h == guardrail else 0.0) for h in self.HEURISTIC_MAP.values()}
            top_features = self._get_top_features(features, n=5)

            self.switching_log.record(
                time=now,
                features=features.tolist(),
                probabilities=probas,
                selected=guardrail,
                switched=switched,
                reason=self._guardrail_reason(features, guardrail),
                confidence=1.0,
                top_features=top_features,
                plain_english=f"Guardrail active. Using {self.HEURISTIC_LABELS.get(guardrail, guardrail)} as safe default.",
            )
            self._current_heuristic = guardrail
            self._current_confidence = 1.0
            self._current_from_guardrail = True
            return

        # ML prediction
        try:
            X = features.reshape(1, -1)
            probas_arr = self._model.predict_proba(X)[0]
            new_idx = int(np.argmax(probas_arr))
            new_heuristic = self.HEURISTIC_MAP.get(new_idx, "fifo")
            new_confidence = float(probas_arr[new_idx])

            probas_dict = {
                self.HEURISTIC_MAP[i]: float(p)
                for i, p in enumerate(probas_arr)
                if i in self.HEURISTIC_MAP
            }
        except Exception as e:
            logger.warning("ML prediction failed: %s", e)
            return

        if self._hysteresis_blocks(new_heuristic, new_confidence, probas_dict):
            top_features = self._get_top_features(features, n=5)
            self.switching_log.record(
                time=now,
                features=features.tolist(),
                probabilities=probas_dict,
                selected=self._current_heuristic,
                switched=False,
                reason="hysteresis_blocked",
                confidence=new_confidence,
                top_features=top_features,
                plain_english=(
                    f"ML suggests {self.HEURISTIC_LABELS.get(new_heuristic, new_heuristic)} "
                    f"({new_confidence:.0%} confident) but hysteresis threshold not met. "
                    f"Keeping {self.HEURISTIC_LABELS.get(self._current_heuristic, self._current_heuristic)}."
                ),
            )
            return

        # Switch (or keep) accepted
        switched = new_heuristic != self._current_heuristic
        top_features = self._get_top_features(features, n=5)
        plain_english = self._generate_explanation(features, new_heuristic, "ml_decision", probas_dict)

        self.switching_log.record(
            time=now,
            features=features.tolist(),
            probabilities=probas_dict,
            selected=new_heuristic,
            switched=switched,
            reason="ml_decision",
            confidence=new_confidence,
            top_features=top_features,
            plain_english=plain_english,
        )

        self._current_heuristic = new_heuristic
        self._current_confidence = new_confidence
        self._current_from_guardrail = False

    def _hysteresis_blocks(
        self,
        new_heuristic: str,
        new_confidence: float,
        probas: Dict[str, float],
    ) -> bool:
        """Return True if hysteresis should veto switching to ``new_heuristic``.

        The suggested heuristic must reach ``(1 + HYSTERESIS_MARGIN)`` times
        the probability the model assigns to the incumbent *in the same
        prediction*. Comparing against the confidence stored when the
        incumbent was chosen would make switching impossible once that stored
        value exceeds ``1 / (1 + HYSTERESIS_MARGIN)``, because the required
        threshold would then be above 1.0.

        Hysteresis is bypassed when the incumbent was forced by a guardrail
        or is still the untouched initial default (stored confidence 0.0).
        """
        if new_heuristic == self._current_heuristic:
            return False
        if self._current_from_guardrail or self._current_confidence <= 0.0:
            return False
        incumbent_prob = probas.get(self._current_heuristic, 0.0)
        return new_confidence < incumbent_prob * (1.0 + self.HYSTERESIS_MARGIN)

    @staticmethod
    def _default_feature_names() -> List[str]:
        """Return the project's scenario feature names (imported lazily)."""
        from src.features import SCENARIO_FEATURE_NAMES

        return list(SCENARIO_FEATURE_NAMES)

    def _scenario_feature_dict(self, features: np.ndarray) -> Dict[str, float]:
        """Map scenario feature names to the values in ``features``."""
        return dict(zip(self._default_feature_names(), features.tolist()))

    def _check_guardrails(self, features: np.ndarray) -> Optional[str]:
        """Check edge-case guardrails. Returns heuristic name or None."""
        feat_dict = self._scenario_feature_dict(features)

        # Guardrail 1: Trivial load
        if feat_dict.get("n_orders_in_system", 0) < self.TRIVIAL_LOAD:
            return "fifo"

        # Guardrail 2: Overload
        if feat_dict.get("zone_utilization_avg", 0.0) > self.OVERLOAD_THRESHOLD:
            return "atc"

        # Guardrail 3: OOD detection (only when the extractor carries
        # training ranges; getattr tolerates extractors without the attribute)
        if getattr(self._fe, "_feature_ranges", None) is not None:
            if self._fe.is_out_of_distribution(features, tolerance=0.10):
                return "atc"

        return None

    def _guardrail_reason(self, features: np.ndarray, heuristic: str) -> str:
        """Return the log reason for a guardrail that selected ``heuristic``.

        Both the overload and the OOD guardrail select ATC; they are told
        apart by re-checking the utilisation feature, since the overload
        guardrail is evaluated first and fires exactly when it exceeds
        ``OVERLOAD_THRESHOLD``.
        """
        if heuristic == "fifo":
            return "guardrail_trivial"
        if heuristic == "atc":
            util_avg = self._scenario_feature_dict(features).get("zone_utilization_avg", 0.0)
            if util_avg > self.OVERLOAD_THRESHOLD:
                return "guardrail_overload"
            return "guardrail_ood"
        return f"guardrail_{heuristic}"

    def _apply_starvation_prevention(
        self,
        jobs: List[Any],
        current_time: float,
    ) -> List[Any]:
        """Move jobs waiting longer than ``STARVATION_LIMIT`` to the front.

        Relative order is preserved within both the starving and the
        non-starving group. Each job's ``arrival_time`` is read once.
        """
        starving: List[Any] = []
        non_starving: List[Any] = []
        # Single pass; avoids the quadratic ``job not in starving`` lookup.
        for job in jobs:
            if (current_time - job.arrival_time) > self.STARVATION_LIMIT:
                starving.append(job)
            else:
                non_starving.append(job)
        return starving + non_starving

    def _get_top_features(self, features: np.ndarray, n: int = 5) -> List[Dict[str, Any]]:
        """Return top-n features by importance with current values.

        Without importances the first ``n`` features are returned with an
        importance of 0.0.
        """
        feat_names = self._feature_names or self._default_feature_names()

        if self._feature_importances is not None:
            top_idx = np.argsort(self._feature_importances)[::-1][:n]
        else:
            top_idx = list(range(min(n, len(feat_names))))

        result = []
        for i in top_idx:
            # Skip indices that have no matching name or value.
            if i < len(feat_names) and i < len(features):
                result.append({
                    "name": feat_names[i],
                    "value": round(float(features[i]), 4),
                    "importance": round(float(self._feature_importances[i]), 4)
                    if self._feature_importances is not None else 0.0,
                })
        return result

    def _generate_explanation(
        self,
        features: np.ndarray,
        heuristic: str,
        reason: str,
        probas: Dict[str, float],
    ) -> str:
        """Generate a plain-English explanation for THIS specific decision.

        Rather than citing the globally most-important feature (which would
        be identical across every decision), the feature with the highest
        per-decision salience is used. Salience is approximated as
        ``importance * (0.5 + deviation)``, where deviation is the distance of
        the current value from the middle of its training range, in units of
        half the range width (1.0 when no range is known).

        Args:
            features: Feature vector of the decision.
            heuristic: Selected heuristic name.
            reason: Accepted for interface compatibility; not used.
            probas: Heuristic name -> probability; supplies the confidence.
        """
        feat_names = self._feature_names or self._default_feature_names()
        feat_dict = dict(zip(feat_names, features.tolist()))
        label = self.HEURISTIC_LABELS.get(heuristic, heuristic)
        confidence = probas.get(heuristic, 0.0)

        if self._feature_importances is not None and len(feat_names) > 0:
            ranges = getattr(self._fe, "_feature_ranges", None) or {}
            salience = np.zeros(len(feat_names), dtype=float)
            for i, name in enumerate(feat_names):
                if i >= len(features) or i >= len(self._feature_importances):
                    continue
                val = float(features[i])
                imp = float(self._feature_importances[i])
                lo_hi = ranges.get(name)
                if lo_hi and lo_hi[1] > lo_hi[0]:
                    mid = 0.5 * (lo_hi[0] + lo_hi[1])
                    half = 0.5 * (lo_hi[1] - lo_hi[0])
                    deviation = abs(val - mid) / max(half, 1e-6)
                else:
                    deviation = 1.0  # no range info -> fall back to importance only
                salience[i] = imp * (0.5 + deviation)  # floor keeps importance relevant

            # Prefer, among the 8 most salient features, one that has a
            # template for this heuristic.
            ranked = np.argsort(salience)[::-1]
            for idx in ranked[:8]:
                if idx >= len(feat_names):
                    continue
                fname = feat_names[idx]
                key = (heuristic, fname)
                if key in self._EXPLANATION_MAP:
                    reason_str = self._EXPLANATION_MAP[key]
                    val = feat_dict.get(fname, 0.0)
                    return (
                        f"DAHS selected {label} ({confidence:.0%} confidence) because "
                        f"{reason_str} ({fname}={val:.2f})."
                    )

            # No template hit: name the most salient feature generically
            if ranked.size > 0:
                idx0 = int(ranked[0])
                if idx0 < len(feat_names):
                    fname = feat_names[idx0]
                    val = feat_dict.get(fname, 0.0)
                    return (
                        f"DAHS selected {label} with {confidence:.0%} confidence; "
                        f"the strongest driver for this decision was "
                        f"{fname}={val:.2f}."
                    )

        # Generic fallback
        return (
            f"DAHS selected {label} with {confidence:.0%} confidence based on "
            f"current system state. This is the predicted optimal heuristic for "
            f"minimizing weighted tardiness and SLA breaches."
        )


# ---------------------------------------------------------------------------
# HybridPriority (ported from DAHS_1)
# ---------------------------------------------------------------------------

class HybridPriority:
    """Dispatch jobs by the score of a trained priority-predictor regressor.

    The model is loaded with ``joblib.load`` at construction time. Each call
    builds one feature row per job (scenario features concatenated with that
    job's features), predicts a score, and returns the jobs in descending
    score order. FIFO ordering is used when no simulation state has been
    supplied yet or when feature extraction / prediction raises.
    """

    def __init__(
        self,
        model_path: Union[Path, str],
        feature_extractor: Any,
    ) -> None:
        """Load the regressor from ``model_path`` and keep the feature extractor.

        Args:
            model_path: Location of the serialized model.
            feature_extractor: Object exposing ``extract_scenario_features``
                and ``extract_job_features``.
        """
        self.model_path = Path(model_path)
        self.feature_extractor = feature_extractor
        # NOTE(review): joblib.load unpickles the file; only load model files
        # from a trusted location.
        self._model = joblib.load(self.model_path)
        self._sim_state: Optional[Dict[str, Any]] = None
        logger.info("HybridPriority loaded model from %s", self.model_path)

    def update_state(self, sim_state: Dict[str, Any]) -> None:
        """Store the simulation state used for feature extraction."""
        self._sim_state = sim_state

    def __call__(
        self,
        jobs: List[Any],
        current_time: float,
        zone_id: int,
    ) -> List[Any]:
        """Dispatch jobs by predicted priority score (descending).

        Args:
            jobs: Jobs to order; an empty list is returned unchanged.
            current_time: Current simulation time, forwarded to the FIFO
                fallback.
            zone_id: Zone identifier, forwarded to the FIFO fallback.

        Returns:
            Jobs sorted by predicted score, highest first (ties keep their
            input order), or the FIFO ordering on fallback.
        """
        if not jobs:
            return jobs

        if self._sim_state is None:
            return self._fifo_fallback(jobs, current_time, zone_id)

        try:
            sf = self.feature_extractor.extract_scenario_features(self._sim_state)
            job_feats = np.stack([
                np.concatenate([sf, self.feature_extractor.extract_job_features(j, self._sim_state)])
                for j in jobs
            ])
            predictions = self._model.predict(job_feats)
            # Sort on the score only, so job objects are never compared.
            ranked = sorted(zip(predictions, jobs), key=lambda x: x[0], reverse=True)
            return [job for _, job in ranked]
        except Exception as exc:
            # Deliberate best-effort: any model/feature failure degrades to FIFO.
            logger.warning("HybridPriority error: %s — falling back to FIFO", exc)
            return self._fifo_fallback(jobs, current_time, zone_id)

    @staticmethod
    def _fifo_fallback(jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Order jobs with the project's FIFO heuristic (imported lazily)."""
        from src.heuristics import fifo_dispatch

        return fifo_dispatch(jobs, current_time, zone_id)


# ---------------------------------------------------------------------------
# Rolling-Horizon Fork Oracle (DAHS 2.1) — hard performance guarantee
# ---------------------------------------------------------------------------

class RollingHorizonOracle:
    """Fork-oracle heuristic selector that re-plans on a rolling horizon.

    Every ``EVAL_INTERVAL`` simulated minutes (or as soon as the number of
    broken stations or the lunch flag changes) the attached simulator is
    snapshotted via ``save_state``. One fork per entry in ``HEURISTIC_NAMES``
    is restored from that snapshot, advanced ``HORIZON`` minutes, and scored
    with a weighted composite of tardiness, SLA breach rate and average
    cycle time. The lowest-scoring heuristic drives dispatching until the
    next evaluation.

    Design notes carried over from the original author: every fork starts
    from the same snapshot (including the preserved RNG), so all forks are
    expected to see identical future arrivals, and the per-window argmin is
    described as an exact oracle whose cumulative cost is <= the minimum
    over the individual heuristics.
    NOTE(review): that bound depends on simulator behaviour not visible in
    this class - confirm before relying on it.

    Compute cost: 6 forks x HORIZON min x (600 / EVAL_INTERVAL) decisions ~=
    21,600 sim-min/day for H=90, a constant multiplier on the base sim time.

    Usage:
        sim = WarehouseSimulator(seed=..., heuristic_fn=lambda j, t, z: j, ...)
        oracle = RollingHorizonOracle()
        oracle.attach_simulator(sim)
        sim.heuristic_fn = lambda jobs, t, z: oracle.dispatch(jobs, t, z)
        sim.run(duration=600.0)
    """

    EVAL_INTERVAL = 15.0
    # Per the original notes: >= 4 x the median job cycle (23 min, Olist),
    # chosen to avoid myopic selections.
    HORIZON = 90.0
    STARVATION_LIMIT = 60.0
    HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]

    # Cost weights aligned with benchmark objective (tardiness-dominant)
    W_TARD = 0.55
    W_SLA = 0.35
    W_CYC = 0.10

    # Tardiness value written to the switching log for a fork that failed or
    # produced a non-finite tardiness (kept for log compatibility).
    _FAILED_TARD = 1e9
    # Heuristics whose composite is within this margin of the best count as
    # tied for the optional ML tie-break.
    _TIE_MARGIN = 0.02

    def __init__(self, ml_model: Optional[Any] = None, feature_extractor: Any = None) -> None:
        """Pure oracle when ml_model is None; hybrid (ML prior) when supplied."""
        self._ml_model = ml_model
        self._fe = feature_extractor
        self._sim: Optional[Any] = None
        self._current_heuristic: str = "fifo"
        # Far in the past so the first dispatch triggers an evaluation.
        self._last_eval_time: float = -999.0
        self._last_breakdown_count: int = 0
        self._last_lunch_state: bool = False
        self.switching_log = SwitchingLog()

    def attach_simulator(self, sim: Any) -> None:
        """Bind to the main simulator so we can snapshot it for forks."""
        self._sim = sim

    def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Alias for :meth:`dispatch` so the oracle can be used as a heuristic_fn."""
        return self.dispatch(jobs, current_time, zone_id)

    def dispatch(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Order ``jobs`` with the currently selected heuristic.

        Re-runs the fork evaluation first when one is due, then applies the
        selected heuristic (FIFO if its name is unknown) and finally moves
        starving jobs to the front. An empty ``jobs`` list is returned as is.
        """
        from src.heuristics import DISPATCH_MAP, fifo_dispatch

        if not jobs:
            return jobs

        # Re-evaluate every EVAL_INTERVAL minutes or on state-changing events
        if self._sim is not None and self._should_reevaluate(current_time):
            self._reevaluate(current_time)

        fn = DISPATCH_MAP.get(self._current_heuristic, fifo_dispatch)
        ordered = fn(jobs, current_time, zone_id)
        return self._apply_starvation_prevention(ordered, current_time)

    # ------------------------------------------------------------------
    # Fork-oracle evaluation
    # ------------------------------------------------------------------

    def _count_broken_stations(self) -> int:
        """Count stations on the attached simulator flagged ``is_broken``."""
        stations = getattr(self._sim, "stations", {})
        return sum(1 for st in stations.values() if getattr(st, "is_broken", False))

    def _is_lunch_active(self) -> Any:
        """Read the simulator's ``_lunch_active`` flag (False when absent)."""
        return getattr(self._sim, "_lunch_active", False)

    def _should_reevaluate(self, now: float) -> bool:
        """Return True when the interval elapsed or a disruption state changed."""
        if self._sim is None:
            return False
        if now - self._last_eval_time >= self.EVAL_INTERVAL:
            return True
        # disruption events
        if self._count_broken_stations() != self._last_breakdown_count:
            return True
        return self._is_lunch_active() != self._last_lunch_state

    @staticmethod
    def _normalize(values: np.ndarray) -> np.ndarray:
        """Min-max scale the finite entries to [0, 1]; non-finite become 1.0.

        Non-finite entries (failed forks, undefined metrics) are excluded
        from the min/max so they cannot compress the spread between the
        usable values; they receive the worst score instead. When no entry
        is finite, or all finite entries are (nearly) equal, those entries
        score 0.
        """
        scaled = np.zeros_like(values, dtype=float)
        finite = np.isfinite(values)
        if not finite.any():
            return scaled
        scaled[~finite] = 1.0
        lo = float(values[finite].min())
        hi = float(values[finite].max())
        if hi - lo >= 1e-10:
            scaled[finite] = (values[finite] - lo) / (hi - lo)
        return scaled

    def _composite_scores(
        self, raw: "Dict[str, Tuple[float, float, float]]"
    ) -> Dict[str, float]:
        """Combine per-heuristic (tardiness, sla, cycle) into one cost each.

        ``raw`` must hold an entry for every name in ``HEURISTIC_NAMES``;
        NaN/inf marks a metric that could not be measured. Lower is better.
        """
        names = self.HEURISTIC_NAMES
        metrics = np.array([raw[h] for h in names], dtype=float)
        composite = (
            self.W_TARD * self._normalize(metrics[:, 0])
            + self.W_SLA * self._normalize(metrics[:, 1])
            + self.W_CYC * self._normalize(metrics[:, 2])
        )
        return {h: float(composite[i]) for i, h in enumerate(names)}

    def _ml_prior(self) -> Dict[str, float]:
        """Return ML class probabilities per heuristic, or {} if unavailable.

        NOTE(review): probabilities are matched to ``HEURISTIC_NAMES`` by
        position - assumes the model's class order equals that list; confirm
        against the training code.
        """
        prior: Dict[str, float] = {}
        if self._ml_model is None or self._fe is None:
            return prior
        try:
            sim_state = self._sim.get_state_snapshot()
            feats = self._fe.extract_scenario_features(sim_state)
            probs = self._ml_model.predict_proba(feats.reshape(1, -1))[0]
            for i, h in enumerate(self.HEURISTIC_NAMES):
                if i < len(probs):
                    prior[h] = float(probs[i])
        except Exception as e:
            logger.debug("ML prior failed (non-fatal): %s", e)
        return prior

    def _reevaluate(self, now: float) -> None:
        """Fork every heuristic, score the forks, and select the best one.

        Updates the bookkeeping used by ``_should_reevaluate`` first, so a
        failed snapshot does not cause a retry on every dispatch call. The
        decision is appended to ``switching_log``.
        """
        from src.heuristics import DISPATCH_MAP
        from src.simulator import WarehouseSimulator

        self._last_eval_time = now
        self._last_breakdown_count = self._count_broken_stations()
        self._last_lunch_state = self._is_lunch_active()

        try:
            saved = self._sim.save_state()
        except Exception as e:
            logger.warning("Oracle save_state failed: %s", e)
            return

        names = self.HEURISTIC_NAMES
        fork_end = now + self.HORIZON
        nan = float("nan")
        raw: Dict[str, Tuple[float, float, float]] = {}

        for heur in names:
            try:
                fork = WarehouseSimulator.from_state(saved, DISPATCH_MAP[heur])
                fork.step_to(fork_end)
                m = fork.get_partial_metrics(since_time=now)
                raw[heur] = (
                    float(m.total_tardiness),
                    float(m.sla_breach_rate),
                    float(m.avg_cycle_time),
                )
            except Exception as e:
                # Best-effort: a broken fork is ranked worst, not fatal.
                logger.warning("Fork for %s failed at t=%.1f: %s", heur, now, e)
                raw[heur] = (nan, nan, nan)

        scores = self._composite_scores(raw)

        # Optional ML prior (hybrid mode). It can only choose among
        # heuristics whose oracle score is within _TIE_MARGIN of the best.
        ml_probs = self._ml_prior()

        ranked = sorted(names, key=lambda h: scores[h])
        best = ranked[0]
        if ml_probs:
            floor = scores[best]
            tied = [h for h in ranked if scores[h] - floor < self._TIE_MARGIN]
            if len(tied) > 1:
                best = max(tied, key=lambda h: ml_probs.get(h, 0.0))
        # Report the score of the heuristic actually selected.
        best_score = scores[best]

        # Keep the historical sentinel in the log so it stays finite.
        tard_log = {
            h: (raw[h][0] if np.isfinite(raw[h][0]) else self._FAILED_TARD)
            for h in names
        }

        switched = best != self._current_heuristic
        self.switching_log.record(
            time=now,
            features=[float(tard_log[h]) for h in names],
            probabilities={h: round(scores[h], 4) for h in names},
            selected=best,
            switched=switched,
            reason="oracle_fork" if not ml_probs else "hybrid_oracle",
            confidence=1.0 - best_score,  # lower composite -> higher confidence
            top_features=[
                {"name": f"oracle_tard_{h}", "value": round(tard_log[h], 2), "importance": 1.0}
                for h in names
            ],
            plain_english=(
                f"Oracle fork: {best} wins next {int(self.HORIZON)}-min horizon "
                f"(composite score {best_score:.3f})."
            ),
        )
        self._current_heuristic = best

    def _apply_starvation_prevention(self, jobs: List[Any], current_time: float) -> List[Any]:
        """Move jobs waiting longer than ``STARVATION_LIMIT`` to the front.

        Relative order inside the starving and non-starving groups is kept.
        """
        starving: List[Any] = []
        waiting: List[Any] = []
        for job in jobs:
            overdue = (current_time - job.arrival_time) > self.STARVATION_LIMIT
            (starving if overdue else waiting).append(job)
        return starving + waiting


# ---------------------------------------------------------------------------
# Factory helpers
# ---------------------------------------------------------------------------

def load_batchwise_selector(
    model_name: str = "rf",
    feature_extractor: Any = None,
) -> BatchwiseSelector:
    """Load a BatchwiseSelector for a given classifier variant.

    Reads ``selector_<model_name>.joblib`` from ``MODELS_DIR`` together with
    the optional ``feature_names.json`` and ``feature_ranges.json`` sidecars,
    then checks that all artifacts carrying a run hash agree.

    Parameters
    ----------
    model_name : str
        One of "dt", "rf", "xgb".
    feature_extractor : FeatureExtractor
        Feature extraction instance. A default ``FeatureExtractor`` is
        created when omitted.

    Returns
    -------
    BatchwiseSelector
        Selector wrapping the loaded model; feature names and importances
        are ``None`` when they could not be obtained.

    Raises
    ------
    FileNotFoundError
        If the model file does not exist.
    RuntimeError
        If the artifacts carry differing run hashes, or the model's
        ``n_features_in_`` disagrees with the number of feature names.
    """
    import json

    if feature_extractor is None:
        from src.features import FeatureExtractor
        feature_extractor = FeatureExtractor()

    path = MODELS_DIR / f"selector_{model_name}.joblib"
    if not path.exists():
        raise FileNotFoundError(f"Model not found: {path}")
    # NOTE(security): joblib.load unpickles arbitrary objects; only load
    # model files produced by a trusted training run.
    model = joblib.load(path)

    model_hash = getattr(model, "_dahs_run_hash", None)

    # Feature importances come from the model itself, so read them
    # independently of the JSON sidecar: a broken feature_names.json must
    # not discard them.
    feature_importances = None
    try:
        feature_importances = getattr(model, "feature_importances_", None)
    except Exception as exc:
        logger.warning("Failed to read model feature importances: %s", exc)

    # Load feature names (and their run metadata) if available
    feature_names = None
    names_meta: Dict[str, Any] = {}
    try:
        feature_names_path = MODELS_DIR / "feature_names.json"
        if feature_names_path.exists():
            with open(feature_names_path, encoding="utf-8") as f:
                names_data = json.load(f)
            if isinstance(names_data, dict) and "features" in names_data:
                # New layout: {"_meta": {...}, "features": [{"name": ...}, ...]}
                names_meta = names_data.get("_meta") or {}
                feature_names = [d["name"] for d in names_data["features"]]
            else:
                # Legacy layout: a bare list of {"name": ...} entries.
                feature_names = [d["name"] for d in names_data]
    except Exception as exc:
        logger.warning("Failed to load feature_names.json: %s", exc)

    # Load feature ranges for OOD detection
    ranges_meta: Dict[str, Any] = {}
    try:
        ranges_path = MODELS_DIR / "feature_ranges.json"
        if ranges_path.exists():
            feature_extractor.load_feature_ranges(ranges_path)
            ranges_meta = getattr(feature_extractor, "_feature_ranges_meta", {}) or {}
    except Exception as exc:
        logger.warning("Failed to load feature_ranges.json: %s", exc)

    # Validate that all artifacts came from the same training run. Legacy
    # artifacts (hash is None) are tolerated for backwards compatibility,
    # but any present-and-disagreeing hashes raise loudly: a mismatch means
    # someone retrained without regenerating sidecars and the OOD guardrail
    # would otherwise apply stale ranges.
    artifact_hashes = {
        "model": model_hash,
        "feature_ranges": ranges_meta.get("run_hash"),
        "feature_names": names_meta.get("run_hash"),
    }
    present = {k: v for k, v in artifact_hashes.items() if v is not None}
    if len(set(present.values())) > 1:
        raise RuntimeError(
            "DAHS model/artifact hash mismatch — re-run scripts/run_pipeline.py "
            f"to regenerate them in lockstep. Hashes: {artifact_hashes}"
        )
    if feature_names is not None and hasattr(model, "n_features_in_"):
        if model.n_features_in_ != len(feature_names):
            raise RuntimeError(
                f"Model expects {model.n_features_in_} features but "
                f"feature_names.json has {len(feature_names)}. Retrain."
            )

    return BatchwiseSelector(
        model=model,
        feature_extractor=feature_extractor,
        feature_importances=feature_importances,
        feature_names=feature_names,
    )


def load_hybrid_priority(feature_extractor: Any = None) -> HybridPriority:
    """Build the HybridPriority scheduler backed by ``priority_gbr.joblib``.

    A default ``FeatureExtractor`` is created when ``feature_extractor`` is
    not supplied; the model file is looked up under ``MODELS_DIR``.
    """
    extractor = feature_extractor
    if extractor is None:
        # Imported lazily, only when the caller did not pass an extractor.
        from src.features import FeatureExtractor
        extractor = FeatureExtractor()
    return HybridPriority(
        model_path=MODELS_DIR / "priority_gbr.joblib",
        feature_extractor=extractor,
    )