Spaces:
Sleeping
Sleeping
| """ | |
| hybrid_scheduler.py — Batch-wise ML Hybrid Scheduler with Guardrails (DAHS_2) | |
| NEW architecture vs DAHS_1: | |
| - BatchwiseSelector: re-evaluates every 15 min OR on disruption events | |
| - Hysteresis: only switches if >15% more confident | |
| - Edge case guardrails: trivial load, overload, OOD detection | |
| - Starvation prevention: force-promote jobs waiting >60 min | |
| - 3-level interpretability log per evaluation | |
| - Plain English explanations | |
| Also includes (ported from DAHS_1): | |
| - SwitchingLog class | |
| - HybridPriority class | |
| - Factory functions | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
| import joblib | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
| MODELS_DIR = Path(__file__).parent.parent / "models" | |
| # --------------------------------------------------------------------------- | |
| # Switching Log (enhanced for DAHS_2 with evaluation payload) | |
| # --------------------------------------------------------------------------- | |
class SwitchingLog:
    """Records every batch-wise heuristic-selection evaluation made by BatchwiseSelector.

    DAHS_2: Each entry contains full evaluation context including probabilities,
    top features, reason, and plain-English explanation.

    Besides the entry list, running counters track accepted switches,
    evaluations blocked by hysteresis, and guardrail activations.
    """

    HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]

    def __init__(self) -> None:
        self.entries: List[Dict[str, Any]] = []
        self._last_heuristic: Optional[str] = None
        self._switch_count: int = 0
        self._hysteresis_blocked: int = 0
        self._guardrail_activations: int = 0

    def record(
        self,
        time: float,
        features: List[float],
        probabilities: Dict[str, float],
        selected: str,
        switched: bool,
        reason: str,
        confidence: float,
        top_features: List[Dict[str, Any]],
        plain_english: str,
    ) -> None:
        """Record one batch evaluation.

        Counters are updated from ``switched`` and ``reason``: a reason equal
        to ``"hysteresis_blocked"`` bumps the hysteresis counter and any
        reason starting with ``"guardrail"`` bumps the guardrail counter.
        Numeric values are rounded before being stored (time to 2 decimals,
        features/probabilities/confidence to 4).
        """
        if switched:
            self._switch_count += 1
        if reason == "hysteresis_blocked":
            self._hysteresis_blocked += 1
        if reason.startswith("guardrail"):
            self._guardrail_activations += 1
        self._last_heuristic = selected
        self.entries.append({
            "time": round(time, 2),
            "features": [round(float(f), 4) for f in features],
            "probabilities": {k: round(float(v), 4) for k, v in probabilities.items()},
            "selected": selected,
            "switched": switched,
            "reason": reason,
            "confidence": round(confidence, 4),
            "topFeatures": top_features,
            "plainEnglish": plain_english,
        })

    def total_evaluations(self) -> int:
        """Return the number of recorded evaluations."""
        return len(self.entries)

    def switch_count(self) -> int:
        """Return how many recorded evaluations were flagged as a switch."""
        return self._switch_count

    def heuristic_distribution(self) -> Dict[str, float]:
        """Fraction of evaluations assigned to each heuristic (sorted by name).

        Returns an empty dict when nothing has been recorded.
        """
        if not self.entries:
            return {}
        counts: Dict[str, int] = {}
        for e in self.entries:
            h = e["selected"]
            counts[h] = counts.get(h, 0) + 1
        total = len(self.entries)
        return {h: c / total for h, c in sorted(counts.items())}

    def switching_rate(self) -> float:
        """Switches per evaluation.

        The denominator is ``len(entries) - 1``; with fewer than two entries
        the rate is reported as 0.0.
        """
        if len(self.entries) < 2:
            return 0.0
        return self._switch_count / (len(self.entries) - 1)

    def summary(self) -> Dict[str, Any]:
        """Return a human-readable, JSON-serializable summary dict."""
        dist = self.heuristic_distribution()
        return {
            # Must be *called*: storing the bound method would put a function
            # object in the summary and break JSON serialization.
            "totalEvaluations": self.total_evaluations(),
            "switchCount": self._switch_count,
            "switchingRate": round(self.switching_rate(), 4),
            "hysteresisBlocked": self._hysteresis_blocked,
            "guardrailActivations": self._guardrail_activations,
            "distribution": {k: round(v, 4) for k, v in dist.items()},
            "dominantHeuristic": max(dist, key=dist.get) if dist else "none",
        }

    def to_list(self) -> List[Dict[str, Any]]:
        """Return entries as a plain list for JSON serialization.

        The internal list itself is returned (not a copy).
        """
        return self.entries
| # --------------------------------------------------------------------------- | |
| # BatchwiseSelector — Core DAHS_2 scheduler | |
| # --------------------------------------------------------------------------- | |
class BatchwiseSelector:
    """Batch-wise ML heuristic selector with guardrails and hysteresis.

    Re-evaluates every ``EVAL_INTERVAL`` minutes OR when the simulation state
    reports a change in the number of broken stations or in the lunch flag.
    Only switches if the new heuristic is >15% more confident (relative
    hysteresis).

    Edge-case guardrails (checked in this order):
    - Trivial: n_orders < 5 → use FIFO
    - Overload: avg_utilization > 0.92 → lock to ATC
    - OOD: features outside training range ±10% → safe fallback to ATC
    - Starvation: any job waiting >60 min → force-promote
    """

    EVAL_INTERVAL = 15.0        # minutes between re-evaluations
    # Relative margin: new heuristic's probability must exceed current × (1 + margin).
    # Calibration-invariant across RF (broad) and XGB (sharp) predict_proba outputs.
    HYSTERESIS_MARGIN = 0.15
    TRIVIAL_LOAD = 5            # skip ML if fewer jobs
    OVERLOAD_THRESHOLD = 0.92   # lock to ATC
    STARVATION_LIMIT = 60.0     # force-promote starving jobs (minutes)
    OOD_TOLERANCE = 0.10        # tolerance handed to the extractor's OOD check

    # Class index (argmax of predict_proba) -> heuristic name.
    HEURISTIC_MAP = {
        0: "fifo", 1: "priority_edd", 2: "critical_ratio",
        3: "atc", 4: "wspt", 5: "slack",
    }
    HEURISTIC_LABELS = {
        "fifo": "FIFO", "priority_edd": "Priority-EDD",
        "critical_ratio": "Critical-Ratio", "atc": "ATC",
        "wspt": "WSPT", "slack": "Slack",
    }
    # Plain-English reason templates, keyed by (heuristic, feature name).
    _EXPLANATION_MAP = {
        ("atc", "time_pressure_ratio"): "many jobs are nearing their deadlines",
        ("atc", "surge_multiplier"): "demand surging above normal rate",
        ("atc", "zone_utilization_avg"): "warehouse is highly loaded",
        ("critical_ratio", "n_broken_stations"): "station breakdowns are causing bottlenecks",
        ("critical_ratio", "disruption_intensity"): "high disruption intensity detected",
        ("fifo", "zone_utilization_avg"): "load is light, simple ordering is optimal",
        ("fifo", "n_orders_in_system"): "few jobs in system, FIFO is stable",
        ("wspt", "avg_priority_weight"): "high-value short jobs should be prioritized",
        ("wspt", "avg_remaining_proc_time"): "many short jobs in queue",
        ("priority_edd", "n_express_orders_pct"): "high fraction of express orders",
        ("priority_edd", "fraction_already_late"): "many jobs past due date",
        ("slack", "avg_due_date_tightness"): "deadlines are extremely tight",
        ("slack", "sla_breach_rate_current"): "SLA breach rate is rising",
    }

    def __init__(
        self,
        model: Any,
        feature_extractor: Any,
        feature_importances: Optional[np.ndarray] = None,
        feature_names: Optional[List[str]] = None,
    ) -> None:
        """Store the classifier, the feature extractor and optional metadata.

        Parameters
        ----------
        model
            Object exposing ``predict_proba``; column ``i`` of its output is
            interpreted through ``HEURISTIC_MAP``.
        feature_extractor
            Object exposing ``extract_scenario_features`` and, for the OOD
            guardrail, ``_feature_ranges`` / ``is_out_of_distribution``.
        feature_importances
            Optional per-feature importances used for logging/explanations.
        feature_names
            Optional feature names aligned with the feature vector.
        """
        self._model = model
        self._fe = feature_extractor
        self._feature_importances = feature_importances
        self._feature_names = feature_names or []
        self._current_heuristic: str = "fifo"
        self._current_confidence: float = 0.0
        self._current_from_guardrail: bool = False
        # Sentinel far in the past so the first dispatch triggers an evaluation.
        self._last_eval_time: float = -999.0
        self._last_breakdown_count: int = 0
        self._last_lunch_state: bool = False
        self.switching_log = SwitchingLog()
        self._sim_state: Optional[Dict[str, Any]] = None

    def update_state(self, sim_state: Dict[str, Any]) -> None:
        """Update stored simulation state (called before dispatch)."""
        self._sim_state = sim_state

    # ------------------------------------------------------------------
    # Main dispatch interface
    # ------------------------------------------------------------------
    def dispatch(
        self,
        jobs: List[Any],
        current_time: float,
        zone_id: int,
    ) -> List[Any]:
        """Apply current heuristic, potentially re-evaluating first.

        This is the main entry point called by the simulator's heuristic_fn.
        An empty ``jobs`` list is returned unchanged without re-evaluating.
        Otherwise the selection is refreshed if due, the selected heuristic
        orders the jobs, and starving jobs are moved to the front.
        """
        from src.heuristics import (
            fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch,
            atc_dispatch, wspt_dispatch, slack_dispatch,
        )
        dispatch_fns: Dict[str, Callable] = {
            "fifo": fifo_dispatch,
            "priority_edd": priority_edd_dispatch,
            "critical_ratio": critical_ratio_dispatch,
            "atc": atc_dispatch,
            "wspt": wspt_dispatch,
            "slack": slack_dispatch,
        }
        if not jobs:
            return jobs
        # Re-evaluate if needed (time-based or event-triggered)
        if self._sim_state is not None and self._should_reevaluate(current_time):
            self._reevaluate(current_time)
        # Unknown heuristic names fall back to FIFO.
        fn = dispatch_fns.get(self._current_heuristic, fifo_dispatch)
        ordered = fn(jobs, current_time, zone_id)
        # Starvation prevention: force-promote any job waiting too long.
        return self._apply_starvation_prevention(ordered, current_time)

    def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Callable interface (same as dispatch)."""
        return self.dispatch(jobs, current_time, zone_id)

    # ------------------------------------------------------------------
    # Re-evaluation logic
    # ------------------------------------------------------------------
    def _should_reevaluate(self, now: float) -> bool:
        """Return True if we should re-evaluate the heuristic selection."""
        if self._sim_state is None:
            return False
        # Time-based: every EVAL_INTERVAL minutes
        if now - self._last_eval_time >= self.EVAL_INTERVAL:
            return True
        # Event: breakdown count changed
        n_broken = self._sim_state.get("n_broken_stations", 0)
        if n_broken != self._last_breakdown_count:
            return True
        # Event: lunch state changed
        lunch = self._sim_state.get("lunch_active", False)
        if lunch != self._last_lunch_state:
            return True
        return False

    def _reevaluate(self, now: float) -> None:
        """Perform ML evaluation and decide whether to switch heuristic.

        The evaluation timestamp and event trackers are refreshed up front,
        so a failed feature extraction or prediction is not retried until the
        next trigger. Every completed evaluation is written to
        ``switching_log``.
        """
        if self._sim_state is None:
            return
        self._last_eval_time = now
        self._last_breakdown_count = self._sim_state.get("n_broken_stations", 0)
        self._last_lunch_state = self._sim_state.get("lunch_active", False)

        # Extract features
        try:
            features = self._fe.extract_scenario_features(self._sim_state)
        except Exception as e:
            logger.warning("Feature extraction failed: %s", e)
            return

        # Check guardrails first
        decision = self._evaluate_guardrails(features)
        if decision is not None:
            # Guardrail triggered — record and switch if needed. The reason
            # comes from the guardrail that actually fired rather than being
            # guessed from unrelated simulation-state keys.
            guardrail, reason = decision
            probas = {h: (1.0 if h == guardrail else 0.0) for h in self.HEURISTIC_MAP.values()}
            self.switching_log.record(
                time=now,
                features=features.tolist(),
                probabilities=probas,
                selected=guardrail,
                switched=guardrail != self._current_heuristic,
                reason=reason,
                confidence=1.0,
                top_features=self._get_top_features(features, n=5),
                plain_english=f"Guardrail active. Using {self.HEURISTIC_LABELS.get(guardrail, guardrail)} as safe default.",
            )
            self._current_heuristic = guardrail
            self._current_confidence = 1.0
            self._current_from_guardrail = True
            return

        # ML prediction
        try:
            X = features.reshape(1, -1)
            probas_arr = self._model.predict_proba(X)[0]
            new_idx = int(np.argmax(probas_arr))
            new_heuristic = self.HEURISTIC_MAP.get(new_idx, "fifo")
            new_confidence = float(probas_arr[new_idx])
            probas_dict = {
                self.HEURISTIC_MAP[i]: float(p)
                for i, p in enumerate(probas_arr)
                if i in self.HEURISTIC_MAP
            }
        except Exception as e:
            logger.warning("ML prediction failed: %s", e)
            return

        # Relative-margin hysteresis: switch only if the new heuristic's probability
        # exceeds the current × (1 + HYSTERESIS_MARGIN). This is calibration-invariant
        # across RF (broad probs) and XGB (sharp probs), unlike an additive threshold.
        # Bypassed when current was forced by a guardrail (prevents lock-in on FIFO
        # at t=0 when system was empty).
        if (not self._current_from_guardrail
                and new_heuristic != self._current_heuristic
                and new_confidence < self._current_confidence * (1.0 + self.HYSTERESIS_MARGIN)):
            # Blocked by hysteresis
            self.switching_log.record(
                time=now,
                features=features.tolist(),
                probabilities=probas_dict,
                selected=self._current_heuristic,
                switched=False,
                reason="hysteresis_blocked",
                confidence=new_confidence,
                top_features=self._get_top_features(features, n=5),
                plain_english=(
                    f"ML suggests {self.HEURISTIC_LABELS.get(new_heuristic, new_heuristic)} "
                    f"({new_confidence:.0%} confident) but hysteresis threshold not met. "
                    f"Keeping {self.HEURISTIC_LABELS.get(self._current_heuristic, self._current_heuristic)}."
                ),
            )
            return

        # Switch (or keep) accepted
        switched = new_heuristic != self._current_heuristic
        top_features = self._get_top_features(features, n=5)
        plain_english = self._generate_explanation(features, new_heuristic, "ml_decision", probas_dict)
        self.switching_log.record(
            time=now,
            features=features.tolist(),
            probabilities=probas_dict,
            selected=new_heuristic,
            switched=switched,
            reason="ml_decision",
            confidence=new_confidence,
            top_features=top_features,
            plain_english=plain_english,
        )
        self._current_heuristic = new_heuristic
        self._current_confidence = new_confidence
        self._current_from_guardrail = False

    def _check_guardrails(self, features: np.ndarray) -> Optional[str]:
        """Check edge-case guardrails. Returns heuristic name or None."""
        decision = self._evaluate_guardrails(features)
        return None if decision is None else decision[0]

    def _evaluate_guardrails(self, features: np.ndarray) -> Optional[Tuple[str, str]]:
        """Name the feature vector and return ``(heuristic, reason)`` or None."""
        from src.features import SCENARIO_FEATURE_NAMES
        feat_dict = dict(zip(SCENARIO_FEATURE_NAMES, features.tolist()))
        return self._guardrail_decision(feat_dict, features)

    def _guardrail_decision(
        self,
        feat_dict: Dict[str, float],
        features: np.ndarray,
    ) -> Optional[Tuple[str, str]]:
        """Apply the guardrails in priority order.

        Returns the forced heuristic together with the log reason
        (``guardrail_trivial`` / ``guardrail_overload`` / ``guardrail_ood``),
        or None when no guardrail fires.
        """
        # Guardrail 1: Trivial load
        if feat_dict.get("n_orders_in_system", 0) < self.TRIVIAL_LOAD:
            return "fifo", "guardrail_trivial"
        # Guardrail 2: Overload
        if feat_dict.get("zone_utilization_avg", 0.0) > self.OVERLOAD_THRESHOLD:
            return "atc", "guardrail_overload"
        # Guardrail 3: OOD detection — skipped when the extractor carries no
        # training ranges (getattr keeps this consistent with the explanation code).
        if getattr(self._fe, "_feature_ranges", None) is not None:
            if self._fe.is_out_of_distribution(features, tolerance=self.OOD_TOLERANCE):
                return "atc", "guardrail_ood"
        return None

    def _apply_starvation_prevention(
        self,
        jobs: List[Any],
        current_time: float,
    ) -> List[Any]:
        """Force-promote jobs that have been waiting >60 minutes.

        Moves starving jobs to the front of the queue regardless of heuristic.
        Relative order inside each group is preserved. Single pass, so no
        quadratic membership test against the starving list.
        """
        starving: List[Any] = []
        waiting: List[Any] = []
        for job in jobs:
            if (current_time - job.arrival_time) > self.STARVATION_LIMIT:
                starving.append(job)
            else:
                waiting.append(job)
        return starving + waiting

    def _get_top_features(self, features: np.ndarray, n: int = 5) -> List[Dict[str, Any]]:
        """Return top-n features by importance with current values.

        Without importances the first ``n`` features are reported with an
        importance of 0.0. Indices beyond the name list or the feature vector
        are skipped.
        """
        from src.features import SCENARIO_FEATURE_NAMES
        feat_names = self._feature_names or SCENARIO_FEATURE_NAMES
        if self._feature_importances is not None:
            top_idx = np.argsort(self._feature_importances)[::-1][:n]
        else:
            top_idx = list(range(min(n, len(feat_names))))
        result = []
        for i in top_idx:
            if i < len(feat_names) and i < len(features):
                result.append({
                    "name": feat_names[i],
                    "value": round(float(features[i]), 4),
                    "importance": round(float(self._feature_importances[i]), 4)
                    if self._feature_importances is not None else 0.0,
                })
        return result

    def _generate_explanation(
        self,
        features: np.ndarray,
        heuristic: str,
        reason: str,
        probas: Dict[str, float],
    ) -> str:
        """Generate a plain-English explanation for THIS specific decision.

        Rather than citing the globally most-important feature (which would
        be identical across every decision), we pick the feature whose
        per-decision contribution is highest. Contribution is approximated as
        importance × (0.5 + normalized distance of the current value from the
        midpoint of its training range). ``reason`` is accepted for interface
        symmetry and is not used in the text.
        """
        from src.features import SCENARIO_FEATURE_NAMES
        feat_names = self._feature_names or list(SCENARIO_FEATURE_NAMES)
        feat_dict = dict(zip(feat_names, features.tolist()))
        label = self.HEURISTIC_LABELS.get(heuristic, heuristic)
        confidence = probas.get(heuristic, 0.0)

        # Try to find a per-decision salient feature that has an explanation
        # template for this heuristic.
        if self._feature_importances is not None and len(feat_names) > 0:
            ranges = getattr(self._fe, "_feature_ranges", None) or {}
            # Compute a salience score per feature: importance × normalized deviation
            salience = np.zeros(len(feat_names), dtype=float)
            for i, name in enumerate(feat_names):
                if i >= len(features) or i >= len(self._feature_importances):
                    continue
                val = float(features[i])
                imp = float(self._feature_importances[i])
                lo_hi = ranges.get(name)
                if lo_hi and lo_hi[1] > lo_hi[0]:
                    mid = 0.5 * (lo_hi[0] + lo_hi[1])
                    half = 0.5 * (lo_hi[1] - lo_hi[0])
                    deviation = abs(val - mid) / max(half, 1e-6)
                else:
                    deviation = 1.0  # no range info -> fall back to importance only
                salience[i] = imp * (0.5 + deviation)  # floor keeps importance relevant

            # Prefer features that have a template for this heuristic
            ranked = np.argsort(salience)[::-1]
            for idx in ranked[:8]:  # look at top 8 salient features
                if idx >= len(feat_names):
                    continue
                fname = feat_names[idx]
                key = (heuristic, fname)
                if key in self._EXPLANATION_MAP:
                    reason_str = self._EXPLANATION_MAP[key]
                    val = feat_dict.get(fname, 0.0)
                    return (
                        f"DAHS selected {label} ({confidence:.0%} confidence) because "
                        f"{reason_str} ({fname}={val:.2f})."
                    )

            # No template hit — name the most salient feature generically
            if ranked.size > 0:
                idx0 = int(ranked[0])
                if idx0 < len(feat_names):
                    fname = feat_names[idx0]
                    val = feat_dict.get(fname, 0.0)
                    return (
                        f"DAHS selected {label} with {confidence:.0%} confidence; "
                        f"the strongest driver for this decision was "
                        f"{fname}={val:.2f}."
                    )

        # Generic fallback
        return (
            f"DAHS selected {label} with {confidence:.0%} confidence based on "
            f"current system state. This is the predicted optimal heuristic for "
            f"minimizing weighted tardiness and SLA breaches."
        )
| # --------------------------------------------------------------------------- | |
| # HybridPriority (ported from DAHS_1) | |
| # --------------------------------------------------------------------------- | |
class HybridPriority:
    """Dispatcher that ranks jobs with a regressor loaded from disk via joblib."""

    def __init__(
        self,
        model_path: Union[Path, str],
        feature_extractor: Any,
    ) -> None:
        """Load the model at ``model_path`` and remember the feature extractor."""
        self.model_path = Path(model_path)
        self.feature_extractor = feature_extractor
        self._model = joblib.load(self.model_path)
        # No simulation snapshot yet; __call__ uses FIFO until one is supplied.
        self._sim_state: Optional[Dict[str, Any]] = None
        logger.info("HybridPriority loaded model from %s", self.model_path)

    def update_state(self, sim_state: Dict[str, Any]) -> None:
        """Replace the stored simulation snapshot used for feature extraction."""
        self._sim_state = sim_state

    def __call__(
        self,
        jobs: List[Any],
        current_time: float,
        zone_id: int,
    ) -> List[Any]:
        """Return ``jobs`` ordered by predicted score, highest first.

        An empty list is returned as-is. Without a stored simulation state,
        or when feature extraction / prediction raises, the FIFO order is
        returned instead.
        """
        from src.heuristics import fifo_dispatch

        if not jobs:
            return jobs

        state = self._sim_state
        if state is None:
            return fifo_dispatch(jobs, current_time, zone_id)

        try:
            scenario_vec = self.feature_extractor.extract_scenario_features(state)
            rows = []
            for job in jobs:
                job_vec = self.feature_extractor.extract_job_features(job, state)
                rows.append(np.concatenate([scenario_vec, job_vec]))
            scores = self._model.predict(np.stack(rows))
            # Stable sort on the score only, so equal scores keep input order.
            scored = list(zip(scores, jobs))
            scored.sort(key=lambda pair: pair[0], reverse=True)
            return [pair[1] for pair in scored]
        except Exception as exc:
            logger.warning("HybridPriority error: %s — falling back to FIFO", exc)
            return fifo_dispatch(jobs, current_time, zone_id)
| # --------------------------------------------------------------------------- | |
| # Rolling-Horizon Fork Oracle (DAHS 2.1) — hard performance guarantee | |
| # --------------------------------------------------------------------------- | |
class RollingHorizonOracle:
    """Fork-oracle selector that picks the best heuristic per evaluation window.

    At each EVAL_INTERVAL minutes it clones the simulator via save_state,
    runs every heuristic forward for HORIZON minutes using the preserved RNG
    (so all forks are intended to see identical future arrivals), then picks
    the argmin of a composite cost matching the benchmark objective.

    Forks that raise are excluded from the selection, and penalty values for
    missing metrics are kept out of the normalization so they cannot flatten
    the differences between healthy forks. If every fork fails the current
    heuristic is kept.

    Compute cost: 6 forks × HORIZON min × (600 / EVAL_INTERVAL) decisions ≈
    21,600 sim-min/day for H=90 — a constant multiplier on the base sim time.

    Usage:
        sim = WarehouseSimulator(seed=..., heuristic_fn=lambda j, t, z: j, ...)
        oracle = RollingHorizonOracle()
        oracle.attach_simulator(sim)
        sim.heuristic_fn = lambda jobs, t, z: oracle.dispatch(jobs, t, z)
        sim.run(duration=600.0)
    """

    EVAL_INTERVAL = 15.0
    HORIZON = 90.0  # ≥ median job cycle (23 min Olist) × 4 — eliminates myopia
    STARVATION_LIMIT = 60.0
    HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"]
    # Cost weights aligned with benchmark objective (tardiness-dominant)
    W_TARD = 0.55
    W_SLA = 0.35
    W_CYC = 0.10
    # Composite-score gap under which the ML prior may break a tie.
    TIE_TOLERANCE = 0.02
    # Logged stand-ins (tardiness, SLA breach rate, cycle time) for a metric
    # that is non-finite or whose fork failed.
    _FORK_PENALTY = (1e9, 1.0, 1e6)

    def __init__(self, ml_model: Optional[Any] = None, feature_extractor: Any = None) -> None:
        """Pure oracle when ml_model is None; hybrid (ML prior) when supplied."""
        self._ml_model = ml_model
        self._fe = feature_extractor
        self._sim: Optional[Any] = None
        self._current_heuristic: str = "fifo"
        # Sentinel far in the past so the first dispatch triggers an evaluation.
        self._last_eval_time: float = -999.0
        self._last_breakdown_count: int = 0
        self._last_lunch_state: bool = False
        self.switching_log = SwitchingLog()

    def attach_simulator(self, sim: Any) -> None:
        """Bind to the main simulator so we can snapshot it for forks."""
        self._sim = sim

    def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Callable interface (same as dispatch)."""
        return self.dispatch(jobs, current_time, zone_id)

    def dispatch(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]:
        """Order ``jobs`` with the currently selected heuristic.

        An empty list is returned unchanged. Otherwise the selection is
        refreshed if due, the heuristic is applied (FIFO for unknown names)
        and starving jobs are moved to the front.
        """
        from src.heuristics import DISPATCH_MAP, fifo_dispatch
        if not jobs:
            return jobs
        # Re-evaluate every EVAL_INTERVAL minutes or on state-changing events
        if self._sim is not None and self._should_reevaluate(current_time):
            self._reevaluate(current_time)
        fn = DISPATCH_MAP.get(self._current_heuristic, fifo_dispatch)
        ordered = fn(jobs, current_time, zone_id)
        return self._apply_starvation_prevention(ordered, current_time)

    # ------------------------------------------------------------------
    # Fork-oracle evaluation
    # ------------------------------------------------------------------
    def _count_broken_stations(self) -> int:
        """Count attached-simulator stations whose ``is_broken`` flag is truthy."""
        stations = getattr(self._sim, "stations", {})
        return sum(1 for st in stations.values() if getattr(st, "is_broken", False))

    def _should_reevaluate(self, now: float) -> bool:
        """Return True when the interval elapsed or a disruption flag changed."""
        if self._sim is None:
            return False
        if now - self._last_eval_time >= self.EVAL_INTERVAL:
            return True
        # disruption events
        if self._count_broken_stations() != self._last_breakdown_count:
            return True
        lunch = getattr(self._sim, "_lunch_active", False)
        if lunch != self._last_lunch_state:
            return True
        return False

    def _reevaluate(self, now: float) -> None:
        """Fork all heuristics, score them, and select the best one.

        Trackers are refreshed first, so a failed snapshot is not retried
        until the next trigger. The outcome is written to ``switching_log``.
        """
        from src.heuristics import DISPATCH_MAP
        from src.simulator import WarehouseSimulator

        self._last_eval_time = now
        self._last_breakdown_count = self._count_broken_stations()
        self._last_lunch_state = getattr(self._sim, "_lunch_active", False)

        try:
            saved = self._sim.save_state()
        except Exception as e:
            logger.warning("Oracle save_state failed: %s", e)
            return

        fork_end = now + self.HORIZON
        raw: Dict[str, Tuple[float, float, float]] = {}
        valid: Dict[str, Tuple[bool, bool, bool]] = {}
        failed = set()
        for heur in self.HEURISTIC_NAMES:
            try:
                fork = WarehouseSimulator.from_state(saved, DISPATCH_MAP[heur])
                fork.step_to(fork_end)
                m = fork.get_partial_metrics(since_time=now)
                measured = (m.total_tardiness, m.sla_breach_rate, m.avg_cycle_time)
                ok = tuple(bool(np.isfinite(v)) for v in measured)
                values = tuple(
                    float(v) if good else penalty
                    for v, good, penalty in zip(measured, ok, self._FORK_PENALTY)
                )
            except Exception as e:
                logger.warning("Fork for %s failed at t=%.1f: %s", heur, now, e)
                values = self._FORK_PENALTY
                ok = (False, False, False)
                failed.add(heur)
            raw[heur] = values
            valid[heur] = ok

        candidates = [h for h in self.HEURISTIC_NAMES if h not in failed]
        if not candidates:
            # Nothing was measured: do not pretend FIFO won the comparison.
            logger.warning(
                "All oracle forks failed at t=%.1f; keeping %s", now, self._current_heuristic
            )
            return

        scores = self._composite_scores(raw, valid)

        # Optional ML prior for tie-breaking (Hybrid mode). Does NOT override
        # oracle-chosen winner; only nudges among near-ties.
        ml_probs: Dict[str, float] = {}
        if self._ml_model is not None and self._fe is not None:
            try:
                sim_state = self._sim.get_state_snapshot()
                feats = self._fe.extract_scenario_features(sim_state)
                probs = self._ml_model.predict_proba(feats.reshape(1, -1))[0]
                for i, h in enumerate(self.HEURISTIC_NAMES):
                    if i < len(probs):
                        ml_probs[h] = float(probs[i])
            except Exception as e:
                logger.debug("ML prior failed (non-fatal): %s", e)

        # Pick best oracle score among forks that ran; break near-ties by
        # highest ML probability. sorted() is stable, so exact ties keep the
        # HEURISTIC_NAMES order.
        sorted_h = sorted(candidates, key=lambda h: scores[h])
        best = sorted_h[0]
        best_score = scores[best]
        if ml_probs:
            tied = [h for h in sorted_h if scores[h] - best_score < self.TIE_TOLERANCE]
            if len(tied) > 1:
                best = max(tied, key=lambda h: ml_probs.get(h, 0.0))

        switched = best != self._current_heuristic
        self.switching_log.record(
            time=now,
            features=[float(raw[h][0]) for h in self.HEURISTIC_NAMES],
            probabilities={h: round(scores[h], 4) for h in self.HEURISTIC_NAMES},
            selected=best,
            switched=switched,
            reason="oracle_fork" if not ml_probs else "hybrid_oracle",
            confidence=1.0 - best_score,  # lower composite → higher confidence
            top_features=[
                {"name": f"oracle_tard_{h}", "value": round(raw[h][0], 2), "importance": 1.0}
                for h in self.HEURISTIC_NAMES
            ],
            plain_english=(
                f"Oracle fork: {best} wins next {int(self.HORIZON)}-min horizon "
                f"(composite score {best_score:.3f})."
            ),
        )
        self._current_heuristic = best

    def _composite_scores(
        self,
        raw: Dict[str, Tuple[float, float, float]],
        valid: Dict[str, Tuple[bool, bool, bool]],
    ) -> Dict[str, float]:
        """Combine min-max-normalized metrics into one cost per heuristic.

        Each metric is normalized across the heuristics whose value is valid;
        heuristics with an invalid value get the worst normalized value (1.0)
        for that metric. A metric with no valid value, or with no spread,
        contributes 0 for every heuristic.
        """
        names = self.HEURISTIC_NAMES
        composite = np.zeros(len(names), dtype=float)
        for col, weight in enumerate((self.W_TARD, self.W_SLA, self.W_CYC)):
            values = np.array([raw[h][col] for h in names], dtype=float)
            mask = np.array([valid[h][col] for h in names], dtype=bool)
            normed = np.zeros(len(names), dtype=float)
            if mask.any():
                lo = float(values[mask].min())
                hi = float(values[mask].max())
                if hi - lo >= 1e-10:
                    normed[mask] = (values[mask] - lo) / (hi - lo)
                normed[~mask] = 1.0
            composite += weight * normed
        return {h: float(composite[i]) for i, h in enumerate(names)}

    def _apply_starvation_prevention(self, jobs: List[Any], current_time: float) -> List[Any]:
        """Move jobs waiting longer than STARVATION_LIMIT to the front.

        Relative order inside each group is preserved; done in one pass.
        """
        starving: List[Any] = []
        waiting: List[Any] = []
        for job in jobs:
            if (current_time - job.arrival_time) > self.STARVATION_LIMIT:
                starving.append(job)
            else:
                waiting.append(job)
        return starving + waiting
| # --------------------------------------------------------------------------- | |
| # Factory helpers | |
| # --------------------------------------------------------------------------- | |
def load_batchwise_selector(
    model_name: str = "rf",
    feature_extractor: Any = None,
) -> BatchwiseSelector:
    """Load a BatchwiseSelector for a given classifier variant.

    Parameters
    ----------
    model_name : str
        One of "dt", "rf", "xgb". Selects ``selector_<model_name>.joblib``
        under ``MODELS_DIR``.
    feature_extractor : FeatureExtractor
        Feature extraction instance. A default ``FeatureExtractor`` is
        created when omitted.

    Returns
    -------
    BatchwiseSelector
        Selector wired with the model, the extractor and whatever optional
        metadata (feature names, importances, ranges) could be loaded.

    Raises
    ------
    FileNotFoundError
        If the model file does not exist.
    RuntimeError
        If the artifacts carry disagreeing run hashes, or the model's
        ``n_features_in_`` differs from the number of loaded feature names.
    """
    import json

    if feature_extractor is None:
        from src.features import FeatureExtractor
        feature_extractor = FeatureExtractor()

    path = MODELS_DIR / f"selector_{model_name}.joblib"
    if not path.exists():
        raise FileNotFoundError(f"Model not found: {path}")
    # NOTE: joblib.load unpickles arbitrary objects — only load trusted artifacts.
    model = joblib.load(path)
    model_hash = getattr(model, "_dahs_run_hash", None)

    # Feature importances come from the model alone, so they are read
    # independently of the JSON sidecar: a broken feature_names.json must not
    # also discard them.
    feature_importances = None
    try:
        feature_importances = getattr(model, "feature_importances_", None)
    except Exception as exc:
        logger.warning("Failed to read model feature importances: %s", exc)

    # Load feature names if available (best effort)
    feature_names = None
    names_meta: Dict[str, Any] = {}
    try:
        feature_names_path = MODELS_DIR / "feature_names.json"
        if feature_names_path.exists():
            with open(feature_names_path, encoding="utf-8") as f:
                names_data = json.load(f)
            if isinstance(names_data, dict) and "features" in names_data:
                # Tolerate an explicit null "_meta".
                names_meta = names_data.get("_meta") or {}
                feature_names = [d["name"] for d in names_data["features"]]
            else:
                feature_names = [d["name"] for d in names_data]
    except Exception as exc:
        logger.warning("Failed to load feature_names.json: %s", exc)

    # Load feature ranges for OOD detection (best effort)
    ranges_meta: Dict[str, Any] = {}
    try:
        ranges_path = MODELS_DIR / "feature_ranges.json"
        if ranges_path.exists():
            feature_extractor.load_feature_ranges(ranges_path)
            ranges_meta = getattr(feature_extractor, "_feature_ranges_meta", {}) or {}
    except Exception as exc:
        logger.warning("Failed to load feature_ranges.json: %s", exc)

    # Validate that all artifacts came from the same training run. Legacy
    # artifacts (model_hash is None) are tolerated for backwards compatibility,
    # but any present-and-disagreeing hashes raise loudly — a mismatch means
    # someone retrained without regenerating sidecars and the OOD guardrail
    # would otherwise apply stale ranges.
    artifact_hashes = {
        "model": model_hash,
        "feature_ranges": ranges_meta.get("run_hash"),
        "feature_names": names_meta.get("run_hash"),
    }
    present = {k: v for k, v in artifact_hashes.items() if v is not None}
    if len(set(present.values())) > 1:
        raise RuntimeError(
            "DAHS model/artifact hash mismatch — re-run scripts/run_pipeline.py "
            f"to regenerate them in lockstep. Hashes: {artifact_hashes}"
        )

    if feature_names is not None and hasattr(model, "n_features_in_"):
        if model.n_features_in_ != len(feature_names):
            raise RuntimeError(
                f"Model expects {model.n_features_in_} features but "
                f"feature_names.json has {len(feature_names)}. Retrain."
            )

    return BatchwiseSelector(
        model=model,
        feature_extractor=feature_extractor,
        feature_importances=feature_importances,
        feature_names=feature_names,
    )
def load_hybrid_priority(feature_extractor: Any = None) -> HybridPriority:
    """Build a HybridPriority backed by ``priority_gbr.joblib`` in MODELS_DIR.

    A default ``FeatureExtractor`` is created when none is passed in.
    """
    extractor = feature_extractor
    if extractor is None:
        from src.features import FeatureExtractor
        extractor = FeatureExtractor()
    return HybridPriority(
        model_path=MODELS_DIR / "priority_gbr.joblib",
        feature_extractor=extractor,
    )