| """
|
| hybrid_scheduler.py — Batch-wise ML Hybrid Scheduler with Guardrails (DAHS_2)
|
|
|
| NEW architecture vs DAHS_1:
|
| - BatchwiseSelector: re-evaluates every 15 min OR on disruption events
|
| - Hysteresis: only switches if >15% more confident
|
| - Edge case guardrails: trivial load, overload, OOD detection
|
| - Starvation prevention: force-promote jobs waiting >60 min
|
| - 3-level interpretability log per evaluation
|
| - Plain English explanations
|
|
|
| Also includes (ported from DAHS_1):
|
| - SwitchingLog class
|
| - HybridPriority class
|
| - Factory functions
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import logging
|
| from pathlib import Path
|
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
| import joblib
|
| import numpy as np
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
| MODELS_DIR = Path(__file__).parent.parent / "models"
|
|
|
|
|
|
|
|
|
|
|
|
|
class SwitchingLog:
    """Append-only journal of heuristic-selection evaluations.

    Every call to :meth:`record` stores one entry (timestamp, feature
    values, per-heuristic probabilities, the heuristic that was selected,
    the reason, a confidence value, the top features and a plain-English
    explanation) and updates three running counters: switches,
    hysteresis blocks and guardrail activations.
    """

    HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]

    def __init__(self) -> None:
        self.entries: List[Dict[str, Any]] = []
        self._last_heuristic: Optional[str] = None
        self._switch_count: int = 0
        self._hysteresis_blocked: int = 0
        self._guardrail_activations: int = 0

    def record(
        self,
        time: float,
        features: List[float],
        probabilities: Dict[str, float],
        selected: str,
        switched: bool,
        reason: str,
        confidence: float,
        top_features: List[Dict[str, Any]],
        plain_english: str,
    ) -> None:
        """Store one evaluation and bump the matching counters.

        Numeric values are rounded before storage: ``time`` to 2 decimals,
        features, probabilities and ``confidence`` to 4 decimals.
        """
        # Counters first, so they are updated even if building the entry fails.
        self._switch_count += 1 if switched else 0
        self._hysteresis_blocked += int(reason == "hysteresis_blocked")
        self._guardrail_activations += int(reason.startswith("guardrail"))
        self._last_heuristic = selected

        rounded_features = [round(float(value), 4) for value in features]
        rounded_probas = {
            name: round(float(prob), 4) for name, prob in probabilities.items()
        }
        entry: Dict[str, Any] = {
            "time": round(time, 2),
            "features": rounded_features,
            "probabilities": rounded_probas,
            "selected": selected,
            "switched": switched,
            "reason": reason,
            "confidence": round(confidence, 4),
            "topFeatures": top_features,
            "plainEnglish": plain_english,
        }
        self.entries.append(entry)

    @property
    def total_evaluations(self) -> int:
        """Number of entries recorded so far."""
        return len(self.entries)

    @property
    def switch_count(self) -> int:
        """Number of recorded evaluations flagged as a switch."""
        return self._switch_count

    def heuristic_distribution(self) -> Dict[str, float]:
        """Return the share of evaluations per selected heuristic, keys sorted."""
        total = len(self.entries)
        if total == 0:
            return {}
        tally: Dict[str, int] = {}
        for entry in self.entries:
            name = entry["selected"]
            tally[name] = tally.get(name, 0) + 1
        return {name: tally[name] / total for name in sorted(tally)}

    def switching_rate(self) -> float:
        """Return switches divided by (evaluations - 1); 0.0 below two entries."""
        n_entries = len(self.entries)
        return 0.0 if n_entries < 2 else self._switch_count / (n_entries - 1)

    def summary(self) -> Dict[str, Any]:
        """Return the counters, distribution and dominant heuristic as a dict."""
        shares = self.heuristic_distribution()
        dominant = max(shares, key=shares.get) if shares else "none"
        return {
            "totalEvaluations": self.total_evaluations,
            "switchCount": self._switch_count,
            "switchingRate": round(self.switching_rate(), 4),
            "hysteresisBlocked": self._hysteresis_blocked,
            "guardrailActivations": self._guardrail_activations,
            "distribution": {name: round(share, 4) for name, share in shares.items()},
            "dominantHeuristic": dominant,
        }

    def to_list(self) -> List[Dict[str, Any]]:
        """Return the entry list itself (not a copy) for JSON serialization."""
        return self.entries
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchwiseSelector:
    """Batch-wise ML heuristic selector with guardrails and hysteresis.

    The selector keeps one "current" dispatch heuristic. It re-evaluates that
    choice when ``EVAL_INTERVAL`` has elapsed since the last evaluation, or
    when the stored simulation state shows a different number of broken
    stations or a different lunch flag than at the last evaluation.

    Guardrails (checked before the model is consulted):
        - Trivial load: ``n_orders_in_system < TRIVIAL_LOAD`` -> FIFO
        - Overload: ``zone_utilization_avg > OVERLOAD_THRESHOLD`` -> ATC
        - Out-of-distribution features (10% tolerance) -> ATC

    Hysteresis: the model's top pick only replaces the incumbent when its
    probability is at least ``1 + HYSTERESIS_MARGIN`` times the probability
    the *same prediction* assigns to the incumbent.

    Starvation prevention: after the heuristic has ordered the queue, jobs
    whose wait exceeds ``STARVATION_LIMIT`` are moved to the front.
    """

    EVAL_INTERVAL = 15.0        # time between scheduled re-evaluations
    HYSTERESIS_MARGIN = 0.15    # relative advantage required to switch
    TRIVIAL_LOAD = 5            # below this many orders, always FIFO
    OVERLOAD_THRESHOLD = 0.92   # average zone utilization that locks ATC
    STARVATION_LIMIT = 60.0     # maximum wait before force-promotion

    # Class label -> heuristic key.
    HEURISTIC_MAP = {
        0: "fifo", 1: "priority_edd", 2: "critical_ratio",
        3: "atc", 4: "wspt", 5: "slack",
    }
    # Heuristic key -> display label used in explanations.
    HEURISTIC_LABELS = {
        "fifo": "FIFO", "priority_edd": "Priority-EDD",
        "critical_ratio": "Critical-Ratio", "atc": "ATC",
        "wspt": "WSPT", "slack": "Slack",
    }

    # (heuristic, feature) -> phrase used when that feature drives the choice.
    _EXPLANATION_MAP = {
        ("atc", "time_pressure_ratio"): "many jobs are nearing their deadlines",
        ("atc", "surge_multiplier"): "demand surging above normal rate",
        ("atc", "zone_utilization_avg"): "warehouse is highly loaded",
        ("critical_ratio", "n_broken_stations"): "station breakdowns are causing bottlenecks",
        ("critical_ratio", "disruption_intensity"): "high disruption intensity detected",
        ("fifo", "zone_utilization_avg"): "load is light, simple ordering is optimal",
        ("fifo", "n_orders_in_system"): "few jobs in system, FIFO is stable",
        ("wspt", "avg_priority_weight"): "high-value short jobs should be prioritized",
        ("wspt", "avg_remaining_proc_time"): "many short jobs in queue",
        ("priority_edd", "n_express_orders_pct"): "high fraction of express orders",
        ("priority_edd", "fraction_already_late"): "many jobs past due date",
        ("slack", "avg_due_date_tightness"): "deadlines are extremely tight",
        ("slack", "sla_breach_rate_current"): "SLA breach rate is rising",
    }

    def __init__(
        self,
        model: Any,
        feature_extractor: Any,
        feature_importances: Optional[np.ndarray] = None,
        feature_names: Optional[List[str]] = None,
    ) -> None:
        """Store the classifier, feature extractor and optional metadata.

        Parameters
        ----------
        model
            Object exposing ``predict_proba``; an optional ``classes_``
            attribute is used to map probability columns to labels.
        feature_extractor
            Object exposing ``extract_scenario_features`` and, for the OOD
            guardrail, ``_feature_ranges`` / ``is_out_of_distribution``.
        feature_importances
            Per-feature importances used for logging and explanations.
        feature_names
            Names aligned with the feature vector; when empty the project's
            ``SCENARIO_FEATURE_NAMES`` is used instead.
        """
        self._model = model
        self._fe = feature_extractor
        self._feature_importances = feature_importances
        self._feature_names = feature_names or []

        self._current_heuristic: str = "fifo"
        self._current_confidence: float = 0.0
        self._current_from_guardrail: bool = False
        # Far in the past so the first dispatch triggers an evaluation.
        self._last_eval_time: float = -999.0
        self._last_breakdown_count: int = 0
        self._last_lunch_state: bool = False

        self.switching_log = SwitchingLog()
        self._sim_state: Optional[Dict[str, Any]] = None

    def update_state(self, sim_state: Dict[str, Any]) -> None:
        """Update stored simulation state (called before dispatch)."""
        self._sim_state = sim_state

    def dispatch(
        self,
        jobs: List[Any],
        current_time: float,
        zone_id: int,
    ) -> List[Any]:
        """Order ``jobs`` with the current heuristic, re-evaluating first if due.

        Returns ``jobs`` unchanged when it is empty. Otherwise the selected
        heuristic orders the queue (FIFO if the stored name is unknown) and
        starving jobs are then promoted to the front.
        """
        from src.heuristics import (
            fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
            atc_dispatch, wspt_dispatch, slack_dispatch,
        )

        dispatch_fns: Dict[str, Callable] = {
            "fifo": fifo_dispatch,
            "priority_edd": priority_edd_dispatch,
            "critical_ratio": critical_ratio_dispatch,
            "atc": atc_dispatch,
            "wspt": wspt_dispatch,
            "slack": slack_dispatch,
        }

        if not jobs:
            return jobs

        # No state has been pushed yet -> nothing to evaluate against.
        if self._sim_state is not None and self._should_reevaluate(current_time):
            self._reevaluate(current_time)

        fn = dispatch_fns.get(self._current_heuristic, fifo_dispatch)
        ordered = fn(jobs, current_time, zone_id)
        return self._apply_starvation_prevention(ordered, current_time)

    def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Callable interface (same as dispatch)."""
        return self.dispatch(jobs, current_time, zone_id)

    def _should_reevaluate(self, now: float) -> bool:
        """Return True if the heuristic selection should be re-evaluated."""
        if self._sim_state is None:
            return False

        # Scheduled evaluation.
        if now - self._last_eval_time >= self.EVAL_INTERVAL:
            return True

        # Disruption: the number of broken stations changed.
        n_broken = self._sim_state.get("n_broken_stations", 0)
        if n_broken != self._last_breakdown_count:
            return True

        # Disruption: lunch started or ended.
        lunch = self._sim_state.get("lunch_active", False)
        return lunch != self._last_lunch_state

    def _reevaluate(self, now: float) -> None:
        """Run guardrails, then the model, and decide whether to switch.

        Every completed evaluation appends one entry to ``switching_log``.
        Feature-extraction or prediction failures are logged and leave the
        current heuristic in place; the evaluation timestamp is still
        advanced, so the next scheduled retry is one interval later.
        """
        state = self._sim_state
        if state is None:
            return

        self._last_eval_time = now
        self._last_breakdown_count = state.get("n_broken_stations", 0)
        self._last_lunch_state = state.get("lunch_active", False)

        try:
            features = self._fe.extract_scenario_features(state)
        except Exception as e:
            logger.warning("Feature extraction failed: %s", e)
            return

        # --- Guardrails take precedence over the model -------------------
        guardrail = self._evaluate_guardrails(features)
        if guardrail is not None:
            heuristic, reason = guardrail
            self.switching_log.record(
                time=now,
                features=features.tolist(),
                probabilities={
                    h: (1.0 if h == heuristic else 0.0)
                    for h in self.HEURISTIC_MAP.values()
                },
                selected=heuristic,
                switched=heuristic != self._current_heuristic,
                reason=reason,
                confidence=1.0,
                top_features=self._get_top_features(features, n=5),
                plain_english=f"Guardrail active. Using {self.HEURISTIC_LABELS.get(heuristic, heuristic)} as safe default.",
            )
            self._current_heuristic = heuristic
            self._current_confidence = 1.0
            self._current_from_guardrail = True
            return

        # --- Model prediction ---------------------------------------------
        try:
            probas_arr = self._model.predict_proba(features.reshape(1, -1))[0]
            # Probability columns follow model.classes_, which need not be
            # 0..5 if a class was absent from the training data.
            labels = self._class_labels(len(probas_arr))
            best_col = int(np.argmax(probas_arr))
            new_heuristic = self.HEURISTIC_MAP.get(labels[best_col], "fifo")
            new_confidence = float(probas_arr[best_col])
            probas_dict = {
                self.HEURISTIC_MAP[label]: float(p)
                for label, p in zip(labels, probas_arr)
                if label in self.HEURISTIC_MAP
            }
        except Exception as e:
            logger.warning("ML prediction failed: %s", e)
            return

        top_features = self._get_top_features(features, n=5)

        # --- Hysteresis ----------------------------------------------------
        if self._hysteresis_blocks(new_heuristic, new_confidence, probas_dict):
            self.switching_log.record(
                time=now,
                features=features.tolist(),
                probabilities=probas_dict,
                selected=self._current_heuristic,
                switched=False,
                reason="hysteresis_blocked",
                confidence=new_confidence,
                top_features=top_features,
                plain_english=(
                    f"ML suggests {self.HEURISTIC_LABELS.get(new_heuristic, new_heuristic)} "
                    f"({new_confidence:.0%} confident) but hysteresis threshold not met. "
                    f"Keeping {self.HEURISTIC_LABELS.get(self._current_heuristic, self._current_heuristic)}."
                ),
            )
            return

        # --- Accept the model's decision ------------------------------------
        self.switching_log.record(
            time=now,
            features=features.tolist(),
            probabilities=probas_dict,
            selected=new_heuristic,
            switched=new_heuristic != self._current_heuristic,
            reason="ml_decision",
            confidence=new_confidence,
            top_features=top_features,
            plain_english=self._generate_explanation(
                features, new_heuristic, "ml_decision", probas_dict
            ),
        )

        self._current_heuristic = new_heuristic
        self._current_confidence = new_confidence
        self._current_from_guardrail = False

    def _class_labels(self, n_columns: int) -> List[int]:
        """Return the integer class label for each ``predict_proba`` column.

        Uses ``model.classes_`` when it is present, has ``n_columns`` entries
        and its entries convert to ``int``; otherwise falls back to the
        positional labels ``0..n_columns-1``.
        """
        classes = getattr(self._model, "classes_", None)
        if classes is not None and len(classes) == n_columns:
            try:
                return [int(c) for c in classes]
            except (TypeError, ValueError):
                pass  # non-numeric labels: keep positional mapping
        return list(range(n_columns))

    def _hysteresis_blocks(
        self,
        candidate: str,
        candidate_confidence: float,
        probabilities: Dict[str, float],
    ) -> bool:
        """Return True when a switch to ``candidate`` should be suppressed.

        The candidate is compared with the probability that the current
        prediction gives the incumbent heuristic. Comparing with the
        confidence stored when the incumbent was chosen would make switching
        impossible once that stored value reached 1 / (1 + margin), because
        no probability can exceed 1.0.

        Never blocks when the incumbent came from a guardrail or when the
        candidate is the incumbent.
        """
        if self._current_from_guardrail or candidate == self._current_heuristic:
            return False
        incumbent = probabilities.get(self._current_heuristic, 0.0)
        return candidate_confidence < incumbent * (1.0 + self.HYSTERESIS_MARGIN)

    def _evaluate_guardrails(self, features: np.ndarray) -> Optional[Tuple[str, str]]:
        """Return ``(heuristic, reason)`` for the first guardrail that fires.

        Checks run in order: trivial load, overload, out-of-distribution.
        Returns None when no guardrail applies.
        """
        from src.features import SCENARIO_FEATURE_NAMES

        feat_dict = dict(zip(SCENARIO_FEATURE_NAMES, features.tolist()))

        if feat_dict.get("n_orders_in_system", 0) < self.TRIVIAL_LOAD:
            return "fifo", "guardrail_trivial"

        if feat_dict.get("zone_utilization_avg", 0.0) > self.OVERLOAD_THRESHOLD:
            return "atc", "guardrail_overload"

        # The OOD check is only possible once training ranges are loaded.
        if getattr(self._fe, "_feature_ranges", None) is not None:
            if self._fe.is_out_of_distribution(features, tolerance=0.10):
                return "atc", "guardrail_ood"

        return None

    def _check_guardrails(self, features: np.ndarray) -> Optional[str]:
        """Check edge-case guardrails. Returns heuristic name or None."""
        hit = self._evaluate_guardrails(features)
        return None if hit is None else hit[0]

    def _apply_starvation_prevention(
        self,
        jobs: List[Any],
        current_time: float,
    ) -> List[Any]:
        """Move jobs waiting longer than ``STARVATION_LIMIT`` to the front.

        Relative order inside the starving group and inside the remaining
        group is preserved. Done in one pass; jobs are partitioned by the
        wait-time predicate itself rather than by equality lookups.
        """
        starving: List[Any] = []
        remaining: List[Any] = []
        for job in jobs:
            waited = current_time - job.arrival_time
            (starving if waited > self.STARVATION_LIMIT else remaining).append(job)
        return starving + remaining

    def _get_top_features(self, features: np.ndarray, n: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``n`` features by importance with their current values.

        Without importances the first ``n`` features are returned with an
        importance of 0.0. Indices beyond the name list or the feature
        vector are skipped.
        """
        from src.features import SCENARIO_FEATURE_NAMES

        feat_names = self._feature_names or SCENARIO_FEATURE_NAMES
        importances = self._feature_importances

        if importances is not None:
            top_idx = np.argsort(importances)[::-1][:n]
        else:
            top_idx = list(range(min(n, len(feat_names))))

        result = []
        for i in top_idx:
            if i < len(feat_names) and i < len(features):
                result.append({
                    "name": feat_names[i],
                    "value": round(float(features[i]), 4),
                    "importance": round(float(importances[i]), 4)
                    if importances is not None else 0.0,
                })
        return result

    def _generate_explanation(
        self,
        features: np.ndarray,
        heuristic: str,
        reason: str,
        probas: Dict[str, float],
    ) -> str:
        """Generate a plain-English explanation for this specific decision.

        Each feature gets a salience of ``importance * (0.5 + deviation)``,
        where ``deviation`` is the distance of the current value from the
        midpoint of the feature's known range, in units of the half-range
        (1.0 when no usable range is known). The eight most salient features
        are scanned for one with a canned phrase for ``heuristic``; failing
        that, the single most salient feature is cited; failing that, a
        generic sentence is returned.

        ``reason`` is accepted for interface compatibility and is not used.
        """
        from src.features import SCENARIO_FEATURE_NAMES

        feat_names = self._feature_names or list(SCENARIO_FEATURE_NAMES)
        feat_dict = dict(zip(feat_names, features.tolist()))
        label = self.HEURISTIC_LABELS.get(heuristic, heuristic)
        confidence = probas.get(heuristic, 0.0)

        if self._feature_importances is not None and len(feat_names) > 0:
            ranges = getattr(self._fe, "_feature_ranges", None) or {}

            salience = np.zeros(len(feat_names), dtype=float)
            for i, name in enumerate(feat_names):
                if i >= len(features) or i >= len(self._feature_importances):
                    continue
                val = float(features[i])
                imp = float(self._feature_importances[i])
                lo_hi = ranges.get(name)
                if lo_hi and lo_hi[1] > lo_hi[0]:
                    mid = 0.5 * (lo_hi[0] + lo_hi[1])
                    half = 0.5 * (lo_hi[1] - lo_hi[0])
                    deviation = abs(val - mid) / max(half, 1e-6)
                else:
                    deviation = 1.0
                salience[i] = imp * (0.5 + deviation)

            # Prefer the most salient feature that has a canned phrase.
            ranked = np.argsort(salience)[::-1]
            for idx in ranked[:8]:
                if idx >= len(feat_names):
                    continue
                fname = feat_names[idx]
                key = (heuristic, fname)
                if key in self._EXPLANATION_MAP:
                    reason_str = self._EXPLANATION_MAP[key]
                    val = feat_dict.get(fname, 0.0)
                    return (
                        f"DAHS selected {label} ({confidence:.0%} confidence) because "
                        f"{reason_str} ({fname}={val:.2f})."
                    )

            # Otherwise cite the single most salient feature.
            if ranked.size > 0:
                idx0 = int(ranked[0])
                if idx0 < len(feat_names):
                    fname = feat_names[idx0]
                    val = feat_dict.get(fname, 0.0)
                    return (
                        f"DAHS selected {label} with {confidence:.0%} confidence; "
                        f"the strongest driver for this decision was "
                        f"{fname}={val:.2f}."
                    )

        # Generic fallback when no importances are available.
        return (
            f"DAHS selected {label} with {confidence:.0%} confidence based on "
            f"current system state. This is the predicted optimal heuristic for "
            f"minimizing weighted tardiness and SLA breaches."
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridPriority:
    """Dispatcher that ranks jobs with a regressor loaded from disk.

    Each job is scored from the scenario features concatenated with its own
    job features; jobs are returned in descending score order. FIFO is used
    when no simulation state has been supplied or when scoring fails.
    """

    def __init__(
        self,
        model_path: Union[Path, str],
        feature_extractor: Any,
    ) -> None:
        """Load the model at ``model_path`` via ``joblib.load``."""
        self.model_path = Path(model_path)
        self.feature_extractor = feature_extractor
        self._model = joblib.load(self.model_path)
        self._sim_state: Optional[Dict[str, Any]] = None
        logger.info("HybridPriority loaded model from %s", self.model_path)

    def update_state(self, sim_state: Dict[str, Any]) -> None:
        """Remember the latest simulation state for the next dispatch call."""
        self._sim_state = sim_state

    def __call__(
        self,
        jobs: List[Any],
        current_time: float,
        zone_id: int,
    ) -> List[Any]:
        """Return ``jobs`` ordered by predicted priority, highest first."""
        from src.heuristics import fifo_dispatch

        if not jobs:
            return jobs

        state = self._sim_state
        if state is None:
            # Nothing to compute features from yet.
            return fifo_dispatch(jobs, current_time, zone_id)

        try:
            extractor = self.feature_extractor
            scenario_vec = extractor.extract_scenario_features(state)
            rows = [
                np.concatenate([scenario_vec, extractor.extract_job_features(job, state)])
                for job in jobs
            ]
            scores = self._model.predict(np.stack(rows))
            # Stable sort on the score only, so ties keep their input order.
            scored = sorted(zip(scores, jobs), key=lambda pair: pair[0], reverse=True)
            return [job for _score, job in scored]
        except Exception as exc:
            logger.warning("HybridPriority error: %s — falling back to FIFO", exc)
            return fifo_dispatch(jobs, current_time, zone_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
class RollingHorizonOracle:
    """Rolling-horizon look-ahead selector based on simulator forks.

    At each evaluation the attached simulator is snapshotted with
    ``save_state``. One fork per heuristic is restored from that snapshot,
    stepped ``HORIZON`` forward and scored; the heuristic with the lowest
    composite cost is then used until the next evaluation. Evaluations
    happen every ``EVAL_INTERVAL`` and whenever the broken-station count or
    lunch flag of the simulator changes.

    Composite cost = ``W_TARD`` * tardiness + ``W_SLA`` * SLA breach rate +
    ``W_CYC`` * average cycle time, each min-max normalised across the forks
    that ran successfully in that evaluation.

    NOTE(review): all forks start from the same snapshot, so presumably they
    see identical future arrivals if the snapshot preserves the RNG — confirm
    against ``WarehouseSimulator.save_state``. Nothing in this class proves a
    bound on cumulative cost relative to the best single heuristic: the
    look-ahead windows overlap (HORIZON > EVAL_INTERVAL), scores are
    renormalised per evaluation, and the optional ML tie-break may pick a
    heuristic that is not the strict argmin. Treat it as a greedy look-ahead.

    Compute cost: 6 forks x HORIZON min x (600 / EVAL_INTERVAL) decisions =
    21,600 sim-min for a 600-minute run with H=90.

    Usage:
        sim = WarehouseSimulator(seed=..., heuristic_fn=lambda j, t, z: j, ...)
        oracle = RollingHorizonOracle()
        oracle.attach_simulator(sim)
        sim.heuristic_fn = lambda jobs, t, z: oracle.dispatch(jobs, t, z)
        sim.run(duration=600.0)
    """

    EVAL_INTERVAL = 15.0
    HORIZON = 90.0
    STARVATION_LIMIT = 60.0
    HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]

    # Composite-cost weights.
    W_TARD = 0.55
    W_SLA = 0.35
    W_CYC = 0.10

    # Heuristics whose composite score is within this margin of the best are
    # treated as tied when an ML prior is available.
    TIE_MARGIN = 0.02

    # (tardiness, sla breach rate, cycle time) logged for a fork that raised.
    _FAILED_FORK = (1e9, 1.0, 1e6)

    def __init__(self, ml_model: Optional[Any] = None, feature_extractor: Any = None) -> None:
        """Pure oracle when ml_model is None; hybrid (ML prior) when supplied."""
        self._ml_model = ml_model
        self._fe = feature_extractor
        self._sim: Optional[Any] = None
        self._current_heuristic: str = "fifo"
        # Far in the past so the first dispatch triggers an evaluation.
        self._last_eval_time: float = -999.0
        self._last_breakdown_count: int = 0
        self._last_lunch_state: bool = False
        self.switching_log = SwitchingLog()

    def attach_simulator(self, sim: Any) -> None:
        """Bind to the main simulator so we can snapshot it for forks."""
        self._sim = sim

    def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Callable interface (same as dispatch)."""
        return self.dispatch(jobs, current_time, zone_id)

    def dispatch(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Order ``jobs`` with the current heuristic, re-evaluating first if due.

        Returns ``jobs`` unchanged when empty. Unknown heuristic names fall
        back to FIFO; starving jobs are promoted after ordering.
        """
        from src.heuristics import DISPATCH_MAP, fifo_dispatch

        if not jobs:
            return jobs

        if self._sim is not None and self._should_reevaluate(current_time):
            self._reevaluate(current_time)

        fn = DISPATCH_MAP.get(self._current_heuristic, fifo_dispatch)
        ordered = fn(jobs, current_time, zone_id)
        return self._apply_starvation_prevention(ordered, current_time)

    def _count_broken_stations(self) -> int:
        """Count stations of the attached simulator whose ``is_broken`` is truthy."""
        stations = getattr(self._sim, "stations", {})
        return sum(1 for st in stations.values() if getattr(st, "is_broken", False))

    def _should_reevaluate(self, now: float) -> bool:
        """Return True on schedule or when breakdown/lunch state changed."""
        if self._sim is None:
            return False
        if now - self._last_eval_time >= self.EVAL_INTERVAL:
            return True
        if self._count_broken_stations() != self._last_breakdown_count:
            return True
        return getattr(self._sim, "_lunch_active", False) != self._last_lunch_state

    @staticmethod
    def _finite_or(value: Any, fallback: float) -> float:
        """Return ``float(value)`` when finite, else ``fallback``."""
        return float(value) if np.isfinite(value) else fallback

    def _composite_scores(
        self,
        raw: Dict[str, Tuple[float, float, float]],
        failed: Any = frozenset(),
    ) -> Dict[str, float]:
        """Return the composite cost per heuristic (lower is better).

        Metrics are min-max normalised across the heuristics *not* listed in
        ``failed``, so the sentinel values of a crashed fork cannot flatten
        the differences between the forks that did run. Failed heuristics
        receive the worst possible composite (the sum of the weights). The
        best successful score is at most ``W_SLA + W_CYC``, so a failed fork
        can never fall inside the tie margin.
        """
        worst = self.W_TARD + self.W_SLA + self.W_CYC
        scores: Dict[str, float] = {h: worst for h in self.HEURISTIC_NAMES}
        usable = [h for h in self.HEURISTIC_NAMES if h not in failed]
        if not usable:
            return scores

        def _norm(a: np.ndarray) -> np.ndarray:
            lo, hi = float(a.min()), float(a.max())
            if hi - lo < 1e-10:
                # No spread: this metric cannot discriminate.
                return np.zeros_like(a)
            return (a - lo) / (hi - lo)

        n_t = _norm(np.array([raw[h][0] for h in usable], dtype=float))
        n_s = _norm(np.array([raw[h][1] for h in usable], dtype=float))
        n_c = _norm(np.array([raw[h][2] for h in usable], dtype=float))
        composite = self.W_TARD * n_t + self.W_SLA * n_s + self.W_CYC * n_c
        for i, h in enumerate(usable):
            scores[h] = float(composite[i])
        return scores

    def _select_best(
        self,
        scores: Dict[str, float],
        ml_probs: Dict[str, float],
    ) -> Tuple[str, float]:
        """Return ``(heuristic, its composite score)``.

        The lowest score wins (ties resolved by ``HEURISTIC_NAMES`` order).
        When ``ml_probs`` is non-empty, heuristics within ``TIE_MARGIN`` of
        the lowest score are treated as tied and the one with the highest ML
        probability wins. The returned score always belongs to the returned
        heuristic.
        """
        ranked = sorted(self.HEURISTIC_NAMES, key=lambda h: scores[h])
        best = ranked[0]
        if ml_probs:
            floor = scores[best]
            tied = [h for h in ranked if scores[h] - floor < self.TIE_MARGIN]
            if len(tied) > 1:
                best = max(tied, key=lambda h: ml_probs.get(h, 0.0))
        return best, scores[best]

    def _ml_prior(self) -> Dict[str, float]:
        """Return ML probabilities per heuristic, or ``{}`` if unavailable.

        Probability columns are mapped through ``model.classes_`` when that
        attribute holds integer labels of matching length (label ``i`` maps
        to ``HEURISTIC_NAMES[i]``); otherwise by position. Any failure is
        logged at debug level and yields an empty prior.
        """
        if self._ml_model is None or self._fe is None:
            return {}
        try:
            sim_state = self._sim.get_state_snapshot()
            feats = self._fe.extract_scenario_features(sim_state)
            probs = self._ml_model.predict_proba(feats.reshape(1, -1))[0]

            labels = list(range(len(probs)))
            classes = getattr(self._ml_model, "classes_", None)
            if classes is not None and len(classes) == len(probs):
                try:
                    labels = [int(c) for c in classes]
                except (TypeError, ValueError):
                    pass  # non-numeric labels: keep positional mapping

            return {
                self.HEURISTIC_NAMES[label]: float(p)
                for label, p in zip(labels, probs)
                if 0 <= label < len(self.HEURISTIC_NAMES)
            }
        except Exception as e:
            logger.debug("ML prior failed (non-fatal): %s", e)
            return {}

    def _reevaluate(self, now: float) -> None:
        """Fork every heuristic, score the forks and adopt the best one.

        If the snapshot cannot be taken, or if every fork fails, a warning is
        logged and the current heuristic is kept without adding a log entry.
        """
        from src.heuristics import DISPATCH_MAP
        from src.simulator import WarehouseSimulator

        self._last_eval_time = now
        self._last_breakdown_count = self._count_broken_stations()
        self._last_lunch_state = getattr(self._sim, "_lunch_active", False)

        try:
            saved = self._sim.save_state()
        except Exception as e:
            logger.warning("Oracle save_state failed: %s", e)
            return

        fork_end = now + self.HORIZON
        raw: Dict[str, Tuple[float, float, float]] = {}
        failed = set()

        for heur in self.HEURISTIC_NAMES:
            try:
                fork = WarehouseSimulator.from_state(saved, DISPATCH_MAP[heur])
                fork.step_to(fork_end)
                m = fork.get_partial_metrics(since_time=now)
                raw[heur] = (
                    self._finite_or(m.total_tardiness, 1e9),
                    self._finite_or(m.sla_breach_rate, 1.0),
                    self._finite_or(m.avg_cycle_time, 1e6),
                )
            except Exception as e:
                logger.warning("Fork for %s failed at t=%.1f: %s", heur, now, e)
                raw[heur] = self._FAILED_FORK
                failed.add(heur)

        if len(failed) == len(self.HEURISTIC_NAMES):
            # No evidence at all: do not pretend the first heuristic "won".
            logger.warning(
                "All oracle forks failed at t=%.1f; keeping %s",
                now, self._current_heuristic,
            )
            return

        scores = self._composite_scores(raw, failed)
        ml_probs = self._ml_prior()
        best, best_score = self._select_best(scores, ml_probs)

        self.switching_log.record(
            time=now,
            features=[float(raw[h][0]) for h in self.HEURISTIC_NAMES],
            probabilities={h: round(scores[h], 4) for h in self.HEURISTIC_NAMES},
            selected=best,
            switched=best != self._current_heuristic,
            reason="oracle_fork" if not ml_probs else "hybrid_oracle",
            confidence=1.0 - best_score,
            top_features=[
                {"name": f"oracle_tard_{h}", "value": round(raw[h][0], 2), "importance": 1.0}
                for h in self.HEURISTIC_NAMES
            ],
            plain_english=(
                f"Oracle fork: {best} wins next {int(self.HORIZON)}-min horizon "
                f"(composite score {best_score:.3f})."
            ),
        )
        self._current_heuristic = best

    def _apply_starvation_prevention(self, jobs: List[Any], current_time: float) -> List[Any]:
        """Move jobs waiting longer than ``STARVATION_LIMIT`` to the front.

        Single pass; relative order within each group is preserved.
        """
        starving: List[Any] = []
        remaining: List[Any] = []
        for job in jobs:
            waited = current_time - job.arrival_time
            (starving if waited > self.STARVATION_LIMIT else remaining).append(job)
        return starving + remaining
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_batchwise_selector(
    model_name: str = "rf",
    feature_extractor: Any = None,
) -> BatchwiseSelector:
    """Load a BatchwiseSelector for a given classifier variant.

    Parameters
    ----------
    model_name : str
        One of "dt", "rf", "xgb"; selects ``selector_<model_name>.joblib``
        inside ``MODELS_DIR``.
    feature_extractor : FeatureExtractor
        Feature extraction instance. A default ``FeatureExtractor`` is
        created when omitted.

    Returns
    -------
    BatchwiseSelector
        Selector wired with the model, the extractor and, when available,
        feature names and feature importances.

    Raises
    ------
    FileNotFoundError
        The model file does not exist.
    RuntimeError
        The run hashes recorded in the model, ``feature_ranges.json`` and
        ``feature_names.json`` disagree, or the model's feature count
        differs from the number of names in ``feature_names.json``.
    """
    import json

    if feature_extractor is None:
        from src.features import FeatureExtractor
        feature_extractor = FeatureExtractor()

    path = MODELS_DIR / f"selector_{model_name}.joblib"
    if not path.exists():
        raise FileNotFoundError(f"Model not found: {path}")
    # NOTE: joblib.load unpickles the file; only load model files you trust.
    model = joblib.load(path)

    model_hash = getattr(model, "_dahs_run_hash", None)

    # Importances come from the model itself, so read them independently of
    # feature_names.json: a broken JSON file must not discard them.
    feature_importances = None
    try:
        feature_importances = getattr(model, "feature_importances_", None)
    except Exception as exc:
        logger.warning("Failed to read model feature importances: %s", exc)

    feature_names = None
    names_meta: Dict[str, Any] = {}
    try:
        feature_names_path = MODELS_DIR / "feature_names.json"
        if feature_names_path.exists():
            with open(feature_names_path, encoding="utf-8") as f:
                names_data = json.load(f)
            if isinstance(names_data, dict) and "features" in names_data:
                parsed_names = [d["name"] for d in names_data["features"]]
                parsed_meta = names_data.get("_meta")
            else:
                parsed_names = [d["name"] for d in names_data]
                parsed_meta = None
            # Commit names and metadata together, only after parsing
            # succeeded, so a half-read file cannot leave a stale hash.
            feature_names = parsed_names
            names_meta = parsed_meta if isinstance(parsed_meta, dict) else {}
    except Exception as exc:
        logger.warning("Failed to load feature_names.json: %s", exc)

    ranges_meta: Dict[str, Any] = {}
    try:
        ranges_path = MODELS_DIR / "feature_ranges.json"
        if ranges_path.exists():
            feature_extractor.load_feature_ranges(ranges_path)
            loaded_meta = getattr(feature_extractor, "_feature_ranges_meta", None)
            ranges_meta = loaded_meta if isinstance(loaded_meta, dict) else {}
    except Exception as exc:
        logger.warning("Failed to load feature_ranges.json: %s", exc)

    # All artifacts that carry a run hash must carry the same one; artifacts
    # without a hash are ignored by this check.
    artifact_hashes = {
        "model": model_hash,
        "feature_ranges": ranges_meta.get("run_hash"),
        "feature_names": names_meta.get("run_hash"),
    }
    present = {k: v for k, v in artifact_hashes.items() if v is not None}
    if len(set(present.values())) > 1:
        raise RuntimeError(
            "DAHS model/artifact hash mismatch — re-run scripts/run_pipeline.py "
            f"to regenerate them in lockstep. Hashes: {artifact_hashes}"
        )
    if feature_names is not None and hasattr(model, "n_features_in_"):
        if model.n_features_in_ != len(feature_names):
            raise RuntimeError(
                f"Model expects {model.n_features_in_} features but "
                f"feature_names.json has {len(feature_names)}. Retrain."
            )

    return BatchwiseSelector(
        model=model,
        feature_extractor=feature_extractor,
        feature_importances=feature_importances,
        feature_names=feature_names,
    )
|
|
|
|
|
def load_hybrid_priority(feature_extractor: Any = None) -> HybridPriority:
    """Build a HybridPriority backed by ``priority_gbr.joblib`` in MODELS_DIR.

    A default ``FeatureExtractor`` is created when none is supplied.
    """
    extractor = feature_extractor
    if extractor is None:
        from src.features import FeatureExtractor
        extractor = FeatureExtractor()
    return HybridPriority(
        model_path=MODELS_DIR / "priority_gbr.joblib",
        feature_extractor=extractor,
    )
|
|
|