""" hybrid_scheduler.py — Batch-wise ML Hybrid Scheduler with Guardrails (DAHS_2) NEW architecture vs DAHS_1: - BatchwiseSelector: re-evaluates every 15 min OR on disruption events - Hysteresis: only switches if >15% more confident - Edge case guardrails: trivial load, overload, OOD detection - Starvation prevention: force-promote jobs waiting >60 min - 3-level interpretability log per evaluation - Plain English explanations Also includes (ported from DAHS_1): - SwitchingLog class - HybridPriority class - Factory functions """ from __future__ import annotations import logging from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import joblib import numpy as np logger = logging.getLogger(__name__) MODELS_DIR = Path(__file__).parent.parent / "models" # --------------------------------------------------------------------------- # Switching Log (enhanced for DAHS_2 with evaluation payload) # --------------------------------------------------------------------------- class SwitchingLog: """Records every batch-wise heuristic-selection evaluation made by BatchwiseSelector. DAHS_2: Each entry contains full evaluation context including probabilities, top features, reason, and plain-English explanation. """ HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"] def __init__(self) -> None: self.entries: List[Dict[str, Any]] = [] self._last_heuristic: Optional[str] = None self._switch_count: int = 0 self._hysteresis_blocked: int = 0 self._guardrail_activations: int = 0 def record( self, time: float, features: List[float], probabilities: Dict[str, float], selected: str, switched: bool, reason: str, confidence: float, top_features: List[Dict[str, Any]], plain_english: str, ) -> None: """Record one batch evaluation.""" if switched: self._switch_count += 1 if reason == "hysteresis_blocked": self._hysteresis_blocked += 1 if reason.startswith("guardrail"): self._guardrail_activations += 1 self._last_heuristic = selected self.entries.append({ "time": round(time, 2), "features": [round(float(f), 4) for f in features], "probabilities": {k: round(float(v), 4) for k, v in probabilities.items()}, "selected": selected, "switched": switched, "reason": reason, "confidence": round(confidence, 4), "topFeatures": top_features, "plainEnglish": plain_english, }) @property def total_evaluations(self) -> int: return len(self.entries) @property def switch_count(self) -> int: return self._switch_count def heuristic_distribution(self) -> Dict[str, float]: """Fraction of evaluations assigned to each heuristic.""" if not self.entries: return {} counts: Dict[str, int] = {} for e in self.entries: h = e["selected"] counts[h] = counts.get(h, 0) + 1 total = len(self.entries) return {h: c / total for h, c in sorted(counts.items())} def switching_rate(self) -> float: """Switches per evaluation.""" if len(self.entries) < 2: return 0.0 return self._switch_count / (len(self.entries) - 1) def summary(self) -> Dict[str, Any]: """Return a human-readable summary dict.""" dist = self.heuristic_distribution() return { "totalEvaluations": self.total_evaluations, "switchCount": self._switch_count, "switchingRate": round(self.switching_rate(), 4), "hysteresisBlocked": self._hysteresis_blocked, "guardrailActivations": self._guardrail_activations, "distribution": {k: round(v, 4) for k, v in dist.items()}, "dominantHeuristic": max(dist, key=dist.get) if dist else "none", } def to_list(self) -> List[Dict[str, Any]]: """Return entries as a plain list for JSON serialization.""" return self.entries # --------------------------------------------------------------------------- # BatchwiseSelector — Core DAHS_2 scheduler # --------------------------------------------------------------------------- class BatchwiseSelector: """Batch-wise ML heuristic selector with guardrails and hysteresis. Re-evaluates every 15 minutes OR on disruption events (breakdown, batch arrival, lunch state change). Only switches if new heuristic is >15% more confident (hysteresis). Edge-case guardrails: - Trivial: n_orders < 5 → use FIFO - Overload: avg_utilization > 0.92 → lock to ATC + alert - OOD: features outside training range ±10% → safe fallback to ATC - Starvation: any job waiting >60 min → force-promote """ EVAL_INTERVAL = 15.0 # minutes between re-evaluations # Relative margin: new heuristic's probability must exceed current × (1 + margin). # Calibration-invariant across RF (broad) and XGB (sharp) predict_proba outputs. HYSTERESIS_MARGIN = 0.15 TRIVIAL_LOAD = 5 # skip ML if fewer jobs OVERLOAD_THRESHOLD = 0.92 # lock to ATC STARVATION_LIMIT = 60.0 # force-promote starving jobs (minutes) HEURISTIC_MAP = { 0: "fifo", 1: "priority_edd", 2: "critical_ratio", 3: "atc", 4: "wspt", 5: "slack", } HEURISTIC_LABELS = { "fifo": "FIFO", "priority_edd": "Priority-EDD", "critical_ratio": "Critical-Ratio", "atc": "ATC", "wspt": "WSPT", "slack": "Slack", } # Plain-English reason templates _EXPLANATION_MAP = { ("atc", "time_pressure_ratio"): "many jobs are nearing their deadlines", ("atc", "surge_multiplier"): "demand surging above normal rate", ("atc", "zone_utilization_avg"): "warehouse is highly loaded", ("critical_ratio", "n_broken_stations"): "station breakdowns are causing bottlenecks", ("critical_ratio", "disruption_intensity"): "high disruption intensity detected", ("fifo", "zone_utilization_avg"): "load is light, simple ordering is optimal", ("fifo", "n_orders_in_system"): "few jobs in system, FIFO is stable", ("wspt", "avg_priority_weight"): "high-value short jobs should be prioritized", ("wspt", "avg_remaining_proc_time"): "many short jobs in queue", ("priority_edd", "n_express_orders_pct"): "high fraction of express orders", ("priority_edd", "fraction_already_late"): "many jobs past due date", ("slack", "avg_due_date_tightness"): "deadlines are extremely tight", ("slack", "sla_breach_rate_current"): "SLA breach rate is rising", } def __init__( self, model: Any, feature_extractor: Any, feature_importances: Optional[np.ndarray] = None, feature_names: Optional[List[str]] = None, ) -> None: self._model = model self._fe = feature_extractor self._feature_importances = feature_importances self._feature_names = feature_names or [] self._current_heuristic: str = "fifo" self._current_confidence: float = 0.0 self._current_from_guardrail: bool = False self._last_eval_time: float = -999.0 self._last_breakdown_count: int = 0 self._last_lunch_state: bool = False self.switching_log = SwitchingLog() self._sim_state: Optional[Dict[str, Any]] = None def update_state(self, sim_state: Dict[str, Any]) -> None: """Update stored simulation state (called before dispatch).""" self._sim_state = sim_state # ------------------------------------------------------------------ # Main dispatch interface # ------------------------------------------------------------------ def dispatch( self, jobs: List[Any], current_time: float, zone_id: int, ) -> List[Any]: """Apply current heuristic, potentially re-evaluating first. This is the main entry point called by the simulator's heuristic_fn. Re-evaluates every 15 min or on disruption events. """ from src.heuristics import ( fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch, atc_dispatch, wspt_dispatch, slack_dispatch, ) dispatch_fns: Dict[str, Callable] = { "fifo": fifo_dispatch, "priority_edd": priority_edd_dispatch, "critical_ratio": critical_ratio_dispatch, "atc": atc_dispatch, "wspt": wspt_dispatch, "slack": slack_dispatch, } if not jobs: return jobs # Re-evaluate if needed (time-based or event-triggered) if self._sim_state is not None and self._should_reevaluate(current_time): self._reevaluate(current_time) # Starvation prevention: force-promote any job waiting >60 min fn = dispatch_fns.get(self._current_heuristic, fifo_dispatch) ordered = fn(jobs, current_time, zone_id) ordered = self._apply_starvation_prevention(ordered, current_time) return ordered def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]: """Callable interface (same as dispatch).""" return self.dispatch(jobs, current_time, zone_id) # ------------------------------------------------------------------ # Re-evaluation logic # ------------------------------------------------------------------ def _should_reevaluate(self, now: float) -> bool: """Return True if we should re-evaluate the heuristic selection.""" if self._sim_state is None: return False # Time-based: every 15 minutes if now - self._last_eval_time >= self.EVAL_INTERVAL: return True # Event: breakdown count changed n_broken = self._sim_state.get("n_broken_stations", 0) if n_broken != self._last_breakdown_count: return True # Event: lunch state changed lunch = self._sim_state.get("lunch_active", False) if lunch != self._last_lunch_state: return True return False def _reevaluate(self, now: float) -> None: """Perform ML evaluation and decide whether to switch heuristic.""" if self._sim_state is None: return self._last_eval_time = now self._last_breakdown_count = self._sim_state.get("n_broken_stations", 0) self._last_lunch_state = self._sim_state.get("lunch_active", False) # Extract features try: features = self._fe.extract_scenario_features(self._sim_state) except Exception as e: logger.warning("Feature extraction failed: %s", e) return # Check guardrails first guardrail = self._check_guardrails(features) if guardrail is not None: # Guardrail triggered — record and switch if needed switched = guardrail != self._current_heuristic plain = f"Guardrail active: {guardrail.replace('guardrail_', '')}. Using {guardrail} as safe default." probas = {h: (1.0 if h == guardrail else 0.0) for h in self.HEURISTIC_MAP.values()} top_features = self._get_top_features(features, n=5) reason_map = { "fifo": "guardrail_trivial", "atc": "guardrail_overload" if self._sim_state.get("zone_utilization", {}) else "guardrail_ood", } reason = reason_map.get(guardrail, f"guardrail_{guardrail}") self.switching_log.record( time=now, features=features.tolist(), probabilities=probas, selected=guardrail, switched=switched, reason=reason, confidence=1.0, top_features=top_features, plain_english=f"Guardrail active. Using {self.HEURISTIC_LABELS.get(guardrail, guardrail)} as safe default.", ) self._current_heuristic = guardrail self._current_confidence = 1.0 self._current_from_guardrail = True return # ML prediction try: X = features.reshape(1, -1) probas_arr = self._model.predict_proba(X)[0] new_idx = int(np.argmax(probas_arr)) new_heuristic = self.HEURISTIC_MAP.get(new_idx, "fifo") new_confidence = float(probas_arr[new_idx]) probas_dict = { self.HEURISTIC_MAP[i]: float(p) for i, p in enumerate(probas_arr) if i in self.HEURISTIC_MAP } except Exception as e: logger.warning("ML prediction failed: %s", e) return # Relative-margin hysteresis: switch only if the new heuristic's probability # exceeds the current × (1 + HYSTERESIS_MARGIN). This is calibration-invariant # across RF (broad probs) and XGB (sharp probs), unlike an additive threshold. # Bypassed when current was forced by a guardrail (prevents lock-in on FIFO # at t=0 when system was empty). if (not self._current_from_guardrail and new_heuristic != self._current_heuristic and new_confidence < self._current_confidence * (1.0 + self.HYSTERESIS_MARGIN)): # Blocked by hysteresis top_features = self._get_top_features(features, n=5) self.switching_log.record( time=now, features=features.tolist(), probabilities=probas_dict, selected=self._current_heuristic, switched=False, reason="hysteresis_blocked", confidence=new_confidence, top_features=top_features, plain_english=( f"ML suggests {self.HEURISTIC_LABELS.get(new_heuristic, new_heuristic)} " f"({new_confidence:.0%} confident) but hysteresis threshold not met. " f"Keeping {self.HEURISTIC_LABELS.get(self._current_heuristic, self._current_heuristic)}." ), ) return # Switch (or keep) accepted switched = new_heuristic != self._current_heuristic top_features = self._get_top_features(features, n=5) plain_english = self._generate_explanation(features, new_heuristic, "ml_decision", probas_dict) self.switching_log.record( time=now, features=features.tolist(), probabilities=probas_dict, selected=new_heuristic, switched=switched, reason="ml_decision", confidence=new_confidence, top_features=top_features, plain_english=plain_english, ) self._current_heuristic = new_heuristic self._current_confidence = new_confidence self._current_from_guardrail = False def _check_guardrails(self, features: np.ndarray) -> Optional[str]: """Check edge-case guardrails. Returns heuristic name or None.""" from src.features import SCENARIO_FEATURE_NAMES feat_dict = dict(zip(SCENARIO_FEATURE_NAMES, features.tolist())) # Guardrail 1: Trivial load n_orders = feat_dict.get("n_orders_in_system", 0) if n_orders < self.TRIVIAL_LOAD: return "fifo" # Guardrail 2: Overload util_avg = feat_dict.get("zone_utilization_avg", 0.0) if util_avg > self.OVERLOAD_THRESHOLD: return "atc" # Guardrail 3: OOD detection if self._fe._feature_ranges is not None: if self._fe.is_out_of_distribution(features, tolerance=0.10): return "atc" return None def _apply_starvation_prevention( self, jobs: List[Any], current_time: float, ) -> List[Any]: """Force-promote jobs that have been waiting >60 minutes. Moves starving jobs to the front of the queue regardless of heuristic. """ starving = [j for j in jobs if (current_time - j.arrival_time) > self.STARVATION_LIMIT] non_starving = [j for j in jobs if j not in starving] return starving + non_starving def _get_top_features(self, features: np.ndarray, n: int = 5) -> List[Dict[str, Any]]: """Return top-n features by importance with current values.""" from src.features import SCENARIO_FEATURE_NAMES feat_names = self._feature_names or SCENARIO_FEATURE_NAMES if self._feature_importances is not None: top_idx = np.argsort(self._feature_importances)[::-1][:n] else: top_idx = list(range(min(n, len(feat_names)))) result = [] for i in top_idx: if i < len(feat_names) and i < len(features): result.append({ "name": feat_names[i], "value": round(float(features[i]), 4), "importance": round(float(self._feature_importances[i]), 4) if self._feature_importances is not None else 0.0, }) return result def _generate_explanation( self, features: np.ndarray, heuristic: str, reason: str, probas: Dict[str, float], ) -> str: """Generate a plain-English explanation for THIS specific decision. Rather than citing the globally most-important feature (which would be identical across every decision), we pick the feature whose per-decision contribution is highest. Contribution is approximated as importance × |z-score of current value against training range|. """ from src.features import SCENARIO_FEATURE_NAMES feat_names = self._feature_names or list(SCENARIO_FEATURE_NAMES) feat_dict = dict(zip(feat_names, features.tolist())) label = self.HEURISTIC_LABELS.get(heuristic, heuristic) confidence = probas.get(heuristic, 0.0) # Try to find a per-decision salient feature that has an explanation # template for this heuristic. if self._feature_importances is not None and len(feat_names) > 0: ranges = getattr(self._fe, "_feature_ranges", None) or {} # Compute a salience score per feature: importance × normalized deviation salience = np.zeros(len(feat_names), dtype=float) for i, name in enumerate(feat_names): if i >= len(features) or i >= len(self._feature_importances): continue val = float(features[i]) imp = float(self._feature_importances[i]) lo_hi = ranges.get(name) if lo_hi and lo_hi[1] > lo_hi[0]: mid = 0.5 * (lo_hi[0] + lo_hi[1]) half = 0.5 * (lo_hi[1] - lo_hi[0]) deviation = abs(val - mid) / max(half, 1e-6) else: deviation = 1.0 # no range info -> fall back to importance only salience[i] = imp * (0.5 + deviation) # floor keeps importance relevant # Prefer features that have a template for this heuristic ranked = np.argsort(salience)[::-1] for idx in ranked[:8]: # look at top 8 salient features if idx >= len(feat_names): continue fname = feat_names[idx] key = (heuristic, fname) if key in self._EXPLANATION_MAP: reason_str = self._EXPLANATION_MAP[key] val = feat_dict.get(fname, 0.0) return ( f"DAHS selected {label} ({confidence:.0%} confidence) because " f"{reason_str} ({fname}={val:.2f})." ) # No template hit — name the most salient feature generically if ranked.size > 0: idx0 = int(ranked[0]) if idx0 < len(feat_names): fname = feat_names[idx0] val = feat_dict.get(fname, 0.0) return ( f"DAHS selected {label} with {confidence:.0%} confidence; " f"the strongest driver for this decision was " f"{fname}={val:.2f}." ) # Generic fallback return ( f"DAHS selected {label} with {confidence:.0%} confidence based on " f"current system state. This is the predicted optimal heuristic for " f"minimizing weighted tardiness and SLA breaches." ) # --------------------------------------------------------------------------- # HybridPriority (ported from DAHS_1) # --------------------------------------------------------------------------- class HybridPriority: """Wraps a trained GBR priority-predictor regressor.""" def __init__( self, model_path: Union[Path, str], feature_extractor: Any, ) -> None: self.model_path = Path(model_path) self.feature_extractor = feature_extractor self._model = joblib.load(self.model_path) self._sim_state: Optional[Dict[str, Any]] = None logger.info("HybridPriority loaded model from %s", self.model_path) def update_state(self, sim_state: Dict[str, Any]) -> None: self._sim_state = sim_state def __call__( self, jobs: List[Any], current_time: float, zone_id: int, ) -> List[Any]: """Dispatch jobs by predicted priority score (descending).""" from src.heuristics import fifo_dispatch if not jobs: return jobs if self._sim_state is None: return fifo_dispatch(jobs, current_time, zone_id) try: sf = self.feature_extractor.extract_scenario_features(self._sim_state) job_feats = np.stack([ np.concatenate([sf, self.feature_extractor.extract_job_features(j, self._sim_state)]) for j in jobs ]) predictions = self._model.predict(job_feats) ranked = sorted(zip(predictions, jobs), key=lambda x: x[0], reverse=True) return [job for _, job in ranked] except Exception as exc: from src.heuristics import fifo_dispatch logger.warning("HybridPriority error: %s — falling back to FIFO", exc) return fifo_dispatch(jobs, current_time, zone_id) # --------------------------------------------------------------------------- # Rolling-Horizon Fork Oracle (DAHS 2.1) — hard performance guarantee # --------------------------------------------------------------------------- class RollingHorizonOracle: """Pure fork-oracle selector with a mathematical per-window guarantee. At each EVAL_INTERVAL minutes it clones the simulator via save_state, runs every heuristic forward for HORIZON minutes using the preserved RNG (so all forks see identical future arrivals), then picks the argmin of a composite cost matching the benchmark objective. Because forks are RNG-deterministic, the argmin per window is an exact oracle; summed over the day, cumulative cost is mathematically ≤ min-over-heuristics. Compute cost: 6 forks × HORIZON min × (600 / EVAL_INTERVAL) decisions ≈ 21,600 sim-min/day for H=90 — a constant multiplier on the base sim time. Usage: sim = WarehouseSimulator(seed=..., heuristic_fn=lambda j, t, z: j, ...) oracle = RollingHorizonOracle() oracle.attach_simulator(sim) sim.heuristic_fn = lambda jobs, t, z: oracle.dispatch(jobs, t, z) sim.run(duration=600.0) """ EVAL_INTERVAL = 15.0 HORIZON = 90.0 # ≥ median job cycle (23 min Olist) × 4 — eliminates myopia STARVATION_LIMIT = 60.0 HEURISTIC_NAMES = ["fifo", "priority_edd", "critical_ratio", "atc", "wspt", "slack"] # Cost weights aligned with benchmark objective (tardiness-dominant) W_TARD = 0.55 W_SLA = 0.35 W_CYC = 0.10 def __init__(self, ml_model: Optional[Any] = None, feature_extractor: Any = None) -> None: """Pure oracle when ml_model is None; hybrid (ML prior) when supplied.""" self._ml_model = ml_model self._fe = feature_extractor self._sim: Optional[Any] = None self._current_heuristic: str = "fifo" self._last_eval_time: float = -999.0 self._last_breakdown_count: int = 0 self._last_lunch_state: bool = False self.switching_log = SwitchingLog() def attach_simulator(self, sim: Any) -> None: """Bind to the main simulator so we can snapshot it for forks.""" self._sim = sim def __call__(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]: return self.dispatch(jobs, current_time, zone_id) def dispatch(self, jobs: List[Any], current_time: float, zone_id: int) -> List[Any]: from src.heuristics import DISPATCH_MAP, fifo_dispatch if not jobs: return jobs # Re-evaluate every EVAL_INTERVAL minutes or on state-changing events if self._sim is not None and self._should_reevaluate(current_time): self._reevaluate(current_time) fn = DISPATCH_MAP.get(self._current_heuristic, fifo_dispatch) ordered = fn(jobs, current_time, zone_id) ordered = self._apply_starvation_prevention(ordered, current_time) return ordered # ------------------------------------------------------------------ # Fork-oracle evaluation # ------------------------------------------------------------------ def _should_reevaluate(self, now: float) -> bool: if self._sim is None: return False if now - self._last_eval_time >= self.EVAL_INTERVAL: return True # disruption events n_broken = sum( 1 for st in getattr(self._sim, "stations", {}).values() if getattr(st, "is_broken", False) ) if n_broken != self._last_breakdown_count: return True lunch = getattr(self._sim, "_lunch_active", False) if lunch != self._last_lunch_state: return True return False def _reevaluate(self, now: float) -> None: """Fork all heuristics, score, select best. Hard guarantee lives here.""" from src.heuristics import DISPATCH_MAP from src.simulator import WarehouseSimulator self._last_eval_time = now self._last_breakdown_count = sum( 1 for st in getattr(self._sim, "stations", {}).values() if getattr(st, "is_broken", False) ) self._last_lunch_state = getattr(self._sim, "_lunch_active", False) try: saved = self._sim.save_state() except Exception as e: logger.warning("Oracle save_state failed: %s", e) return fork_end = now + self.HORIZON scores: Dict[str, float] = {} raw: Dict[str, Tuple[float, float, float]] = {} for heur in self.HEURISTIC_NAMES: try: heur_fn = DISPATCH_MAP[heur] fork = WarehouseSimulator.from_state(saved, heur_fn) fork.step_to(fork_end) m = fork.get_partial_metrics(since_time=now) tard = float(m.total_tardiness) if np.isfinite(m.total_tardiness) else 1e9 sla = float(m.sla_breach_rate) if np.isfinite(m.sla_breach_rate) else 1.0 cyc = float(m.avg_cycle_time) if np.isfinite(m.avg_cycle_time) else 1e6 except Exception as e: logger.warning("Fork for %s failed at t=%.1f: %s", heur, now, e) tard, sla, cyc = 1e9, 1.0, 1e6 raw[heur] = (tard, sla, cyc) # Normalize across heuristics so units are comparable, then composite score tards = np.array([raw[h][0] for h in self.HEURISTIC_NAMES]) slas = np.array([raw[h][1] for h in self.HEURISTIC_NAMES]) cycs = np.array([raw[h][2] for h in self.HEURISTIC_NAMES]) def _norm(a: np.ndarray) -> np.ndarray: lo, hi = float(a.min()), float(a.max()) if hi - lo < 1e-10: return np.zeros_like(a) return (a - lo) / (hi - lo) n_t = _norm(tards); n_s = _norm(slas); n_c = _norm(cycs) composite = self.W_TARD * n_t + self.W_SLA * n_s + self.W_CYC * n_c for i, h in enumerate(self.HEURISTIC_NAMES): scores[h] = float(composite[i]) # Optional ML prior for tie-breaking (Hybrid mode). Does NOT override # oracle-chosen winner; only nudges among near-ties. ml_probs: Dict[str, float] = {} if self._ml_model is not None and self._fe is not None: try: sim_state = self._sim.get_state_snapshot() feats = self._fe.extract_scenario_features(sim_state) probs = self._ml_model.predict_proba(feats.reshape(1, -1))[0] for i, h in enumerate(self.HEURISTIC_NAMES): if i < len(probs): ml_probs[h] = float(probs[i]) except Exception as e: logger.debug("ML prior failed (non-fatal): %s", e) # Pick best oracle score; break ties (within 2%) by highest ML probability sorted_h = sorted(self.HEURISTIC_NAMES, key=lambda h: scores[h]) best = sorted_h[0] best_score = scores[best] if ml_probs: tied = [h for h in sorted_h if scores[h] - best_score < 0.02] if len(tied) > 1: best = max(tied, key=lambda h: ml_probs.get(h, 0.0)) switched = best != self._current_heuristic self.switching_log.record( time=now, features=[float(raw[h][0]) for h in self.HEURISTIC_NAMES], probabilities={h: round(scores[h], 4) for h in self.HEURISTIC_NAMES}, selected=best, switched=switched, reason="oracle_fork" if not ml_probs else "hybrid_oracle", confidence=1.0 - best_score, # lower composite → higher confidence top_features=[ {"name": f"oracle_tard_{h}", "value": round(raw[h][0], 2), "importance": 1.0} for h in self.HEURISTIC_NAMES ], plain_english=( f"Oracle fork: {best} wins next {int(self.HORIZON)}-min horizon " f"(composite score {best_score:.3f})." ), ) self._current_heuristic = best def _apply_starvation_prevention(self, jobs: List[Any], current_time: float) -> List[Any]: starving = [j for j in jobs if (current_time - j.arrival_time) > self.STARVATION_LIMIT] non_starving = [j for j in jobs if j not in starving] return starving + non_starving # --------------------------------------------------------------------------- # Factory helpers # --------------------------------------------------------------------------- def load_batchwise_selector( model_name: str = "rf", feature_extractor: Any = None, ) -> BatchwiseSelector: """Load a BatchwiseSelector for a given classifier variant. Parameters ---------- model_name : str One of "dt", "rf", "xgb". feature_extractor : FeatureExtractor Feature extraction instance. """ import json if feature_extractor is None: from src.features import FeatureExtractor feature_extractor = FeatureExtractor() path = MODELS_DIR / f"selector_{model_name}.joblib" if not path.exists(): raise FileNotFoundError(f"Model not found: {path}") model = joblib.load(path) model_hash = getattr(model, "_dahs_run_hash", None) # Load feature importances if available feature_importances = None feature_names = None names_meta: Dict[str, Any] = {} try: feature_names_path = MODELS_DIR / "feature_names.json" if feature_names_path.exists(): with open(feature_names_path) as f: names_data = json.load(f) if isinstance(names_data, dict) and "features" in names_data: names_meta = names_data.get("_meta", {}) feature_names = [d["name"] for d in names_data["features"]] else: feature_names = [d["name"] for d in names_data] if hasattr(model, "feature_importances_"): feature_importances = model.feature_importances_ except Exception as exc: logger.warning("Failed to load feature_names.json: %s", exc) # Load feature ranges for OOD detection ranges_meta: Dict[str, Any] = {} try: ranges_path = MODELS_DIR / "feature_ranges.json" if ranges_path.exists(): feature_extractor.load_feature_ranges(ranges_path) ranges_meta = getattr(feature_extractor, "_feature_ranges_meta", {}) or {} except Exception as exc: logger.warning("Failed to load feature_ranges.json: %s", exc) # Validate that all artifacts came from the same training run. Legacy # artifacts (model_hash is None) are tolerated for backwards compatibility, # but any present-and-disagreeing hashes raise loudly — a mismatch means # someone retrained without regenerating sidecars and the OOD guardrail # would otherwise apply stale ranges. artifact_hashes = { "model": model_hash, "feature_ranges": ranges_meta.get("run_hash"), "feature_names": names_meta.get("run_hash"), } present = {k: v for k, v in artifact_hashes.items() if v is not None} if len(set(present.values())) > 1: raise RuntimeError( "DAHS model/artifact hash mismatch — re-run scripts/run_pipeline.py " f"to regenerate them in lockstep. Hashes: {artifact_hashes}" ) if feature_names is not None and hasattr(model, "n_features_in_"): if model.n_features_in_ != len(feature_names): raise RuntimeError( f"Model expects {model.n_features_in_} features but " f"feature_names.json has {len(feature_names)}. Retrain." ) return BatchwiseSelector( model=model, feature_extractor=feature_extractor, feature_importances=feature_importances, feature_names=feature_names, ) def load_hybrid_priority(feature_extractor: Any = None) -> HybridPriority: """Load the GBR-based HybridPriority scheduler.""" if feature_extractor is None: from src.features import FeatureExtractor feature_extractor = FeatureExtractor() path = MODELS_DIR / "priority_gbr.joblib" return HybridPriority(model_path=path, feature_extractor=feature_extractor)