Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| scripts/foolproof_retrain.py — Failure-tolerant GBR retrain pipeline. | |
| Pipeline: | |
| Step 0: Backup current model -> priority_gbr.backup.joblib | |
| Step 1: Generate targeted preset training data (rotating dispatchers) | |
| Step 2: Augment existing dataset (append, never replace) | |
| Step 3: Train candidate GBR -> priority_gbr.candidate.joblib | |
| Step 4: Verify A: preset benchmark (7 presets) - candidate must hit >= preset_floor wins | |
| Step 5: Verify B: random-seed benchmark (20 seeds) - candidate must hit >= random_floor wins | |
| Step 6: Promote candidate or rollback to backup | |
| Worst-case outcome: original priority_gbr.joblib unchanged. | |
| Usage: | |
| python scripts/foolproof_retrain.py | |
| python scripts/foolproof_retrain.py --preset-floor 7 --random-floor 19 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import multiprocessing as mp | |
| import os | |
| import shutil | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| ROOT = Path(__file__).parent.parent | |
| sys.path.insert(0, str(ROOT)) | |
| # Force UTF-8 stdout on Windows | |
| for _stream in ("stdout", "stderr"): | |
| try: | |
| getattr(sys, _stream).reconfigure(encoding="utf-8", errors="replace") | |
| except Exception: | |
| pass | |
| from src.simulator import WarehouseSimulator | |
| from src.features import FeatureExtractor, SCENARIO_FEATURE_NAMES, JOB_FEATURE_NAMES | |
| from src.heuristics import ( | |
| fifo_dispatch, priority_edd_dispatch, critical_ratio_dispatch, | |
| atc_dispatch, wspt_dispatch, slack_dispatch, | |
| ) | |
| from src.presets import PRESETS, get_preset | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") | |
| logger = logging.getLogger(__name__) | |
| DISPATCH_FNS = { | |
| "fifo": fifo_dispatch, | |
| "priority_edd": priority_edd_dispatch, | |
| "critical_ratio": critical_ratio_dispatch, | |
| "atc": atc_dispatch, | |
| "wspt": wspt_dispatch, | |
| "slack": slack_dispatch, | |
| } | |
| MODELS_DIR = ROOT / "models" | |
| DATA_DIR = ROOT / "data" / "raw" | |
| RESULTS_DIR = ROOT / "results" | |
| LIVE_MODEL = MODELS_DIR / "priority_gbr.joblib" | |
| BACKUP_MODEL = MODELS_DIR / "priority_gbr.backup.joblib" | |
| CANDIDATE_MODEL = MODELS_DIR / "priority_gbr.candidate.joblib" | |
| ORIG_DATA = DATA_DIR / "priority_dataset.csv" | |
| AUG_DATA = DATA_DIR / "priority_dataset_augmented.csv" | |
| # Targeted scenario allocation | |
| PRESET_SCENARIO_BUDGET = { | |
| "Preset-1-FIFO": 300, | |
| "Preset-2-Priority-EDD": 300, | |
| "Preset-3-CR": 300, | |
| "Preset-4-ATC": 1000, # currently losing -> heavy | |
| "Preset-5-WSPT": 1000, # currently losing -> heavy | |
| "Preset-6-Slack": 300, | |
| "Preset-7-RealData": 300, | |
| } | |
| N_POINTS_PER = 12 | |
| N_WORKERS = 4 | |
| # ============================================================================ | |
| # Worker (module-level for Windows spawn compatibility) | |
| # ============================================================================ | |
| def _preset_worker(args: Tuple[int, int, str, str]) -> List[Dict[str, Any]]: | |
| """Run one (seed, preset, dispatcher) scenario, return ~n_points feature rows.""" | |
| seed, n_points, preset_name, dispatcher_name = args | |
| p = get_preset(preset_name) | |
| dispatch_fn = DISPATCH_FNS[dispatcher_name] | |
| fe = FeatureExtractor() | |
| sim = WarehouseSimulator( | |
| seed=seed, | |
| heuristic_fn=dispatch_fn, | |
| feature_extractor=fe, | |
| base_arrival_rate=p.base_arrival_rate, | |
| breakdown_prob=p.breakdown_prob, | |
| batch_arrival_size=p.batch_arrival_size, | |
| lunch_penalty_factor=p.lunch_penalty_factor, | |
| job_type_frequencies=p.job_type_frequencies, | |
| due_date_tightness=p.due_date_tightness, | |
| processing_time_scale=p.processing_time_scale, | |
| ) | |
| sim.run(duration=600.0) | |
| state = sim.get_state_snapshot() | |
| completed = sim.completed_jobs | |
| if not completed: | |
| return [] | |
| _PRIO_W = {"A": 2.0, "B": 1.5, "C": 1.0, "D": 0.8, "E": 3.0} | |
| _DD_OFFSET = {"A": 120, "B": 160, "C": 240, "D": 320, "E": 60} | |
| rng = np.random.default_rng(seed) | |
| sampled = rng.choice(len(completed), | |
| size=min(n_points, len(completed)), replace=False) | |
| rows: List[Dict[str, Any]] = [] | |
| for idx in sampled: | |
| job = completed[int(idx)] | |
| sf = fe.extract_scenario_features(state) | |
| jf = fe.extract_job_features(job, state) | |
| w = _PRIO_W.get(job.job_type, 1.0) | |
| dd_off = _DD_OFFSET.get(job.job_type, 120) | |
| cycle = job.completion_time - job.arrival_time | |
| tard = max(0.0, job.completion_time - job.due_date) | |
| remaining = job.remaining_proc_time() | |
| time_to_due = job.due_date - state["current_time"] | |
| urgency = 1.0 - min(1.0, max(0.0, time_to_due / max(dd_off, 1.0))) | |
| importance = w / 3.0 | |
| efficiency = 1.0 / (1.0 + remaining / 30.0) | |
| delivery_perf = max(0.0, 1.0 - tard / max(dd_off, 1.0)) | |
| score = float(0.30*urgency + 0.25*importance + 0.20*efficiency + 0.25*delivery_perf) | |
| if not np.isfinite(score): | |
| continue | |
| row = { | |
| **{f"sf_{i}": float(v) for i, v in enumerate(sf)}, | |
| **{f"jf_{i}": float(v) for i, v in enumerate(jf)}, | |
| "priority_score": score, | |
| } | |
| rows.append(row) | |
| return rows | |
| # ============================================================================ | |
| # Step 1+2: data generation + augmentation | |
| # ============================================================================ | |
| def generate_augmented_dataset() -> pd.DataFrame: | |
| if not ORIG_DATA.exists(): | |
| raise SystemExit(f"Missing original dataset: {ORIG_DATA}") | |
| logger.info("Loading original dataset: %s", ORIG_DATA) | |
| df_orig = pd.read_csv(ORIG_DATA) | |
| logger.info(" -> %d rows, %d cols", len(df_orig), df_orig.shape[1]) | |
| # Build worker args: rotate dispatchers across seeds within each preset | |
| rotation = ["atc", "wspt", "fifo", "priority_edd", "critical_ratio", "slack"] | |
| args_list: List[Tuple[int, int, str, str]] = [] | |
| seed_base = 50_000 | |
| for preset_name, n_scen in PRESET_SCENARIO_BUDGET.items(): | |
| for k in range(n_scen): | |
| seed = seed_base + k | |
| disp = rotation[k % len(rotation)] | |
| args_list.append((seed, N_POINTS_PER, preset_name, disp)) | |
| seed_base += 100_000 # avoid collisions across presets | |
| total = len(args_list) | |
| logger.info("Generating %d preset scenarios with rotating dispatchers...", total) | |
| new_rows: List[Dict[str, Any]] = [] | |
| t0 = time.time() | |
| ctx = mp.get_context("spawn") | |
| with ctx.Pool(processes=N_WORKERS) as pool: | |
| for i, batch in enumerate(pool.imap_unordered(_preset_worker, args_list), 1): | |
| new_rows.extend(batch) | |
| if i % 100 == 0: | |
| pct = 100 * i / total | |
| elapsed = time.time() - t0 | |
| eta = elapsed * (total - i) / max(i, 1) | |
| logger.info(" progress: %d/%d (%.1f%%) elapsed=%.0fs eta=%.0fs", | |
| i, total, pct, elapsed, eta) | |
| logger.info("Generated %d new rows in %.0fs", len(new_rows), time.time() - t0) | |
| if not new_rows: | |
| raise SystemExit("Preset data generation produced 0 rows -> abort") | |
| df_new = pd.DataFrame(new_rows) | |
| sf_names = {f"sf_{i}": name for i, name in enumerate(SCENARIO_FEATURE_NAMES)} | |
| jf_names = {f"jf_{i}": name for i, name in enumerate(JOB_FEATURE_NAMES)} | |
| df_new.rename(columns={**sf_names, **jf_names}, inplace=True) | |
| df_new = df_new.replace([np.inf, -np.inf], np.nan).dropna() | |
| # Align columns | |
| common_cols = [c for c in df_orig.columns if c in df_new.columns] | |
| if "priority_score" not in common_cols: | |
| common_cols.append("priority_score") | |
| df_orig_a = df_orig[common_cols] | |
| df_new_a = df_new[common_cols] | |
| df_aug = pd.concat([df_orig_a, df_new_a], ignore_index=True) | |
| logger.info("Augmented dataset: %d rows (orig=%d + new=%d)", | |
| len(df_aug), len(df_orig_a), len(df_new_a)) | |
| DATA_DIR.mkdir(parents=True, exist_ok=True) | |
| df_aug.to_csv(AUG_DATA, index=False) | |
| logger.info("Wrote augmented dataset -> %s", AUG_DATA) | |
| return df_aug | |
| # ============================================================================ | |
| # Step 3: train candidate | |
| # ============================================================================ | |
| def train_candidate(df: pd.DataFrame) -> None: | |
| from sklearn.ensemble import GradientBoostingRegressor | |
| from sklearn.metrics import mean_absolute_error, r2_score | |
| from sklearn.model_selection import train_test_split | |
| df = df.replace([np.inf, -np.inf], np.nan).dropna() | |
| feature_cols = [c for c in df.columns if c != "priority_score"] | |
| X = df[feature_cols].values.astype(np.float32) | |
| y = df["priority_score"].values.astype(np.float32) | |
| logger.info("Training data: X=%s y=%s", X.shape, y.shape) | |
| X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.20, random_state=42) | |
| model = GradientBoostingRegressor( | |
| n_estimators=300, max_depth=6, learning_rate=0.05, | |
| subsample=0.8, min_samples_leaf=5, random_state=42, | |
| ) | |
| t0 = time.time() | |
| model.fit(X_tr, y_tr) | |
| logger.info("Fit time: %.1fs", time.time() - t0) | |
| y_hat = model.predict(X_te) | |
| logger.info("Candidate metrics: R2=%.4f MAE=%.4f", | |
| r2_score(y_te, y_hat), mean_absolute_error(y_te, y_hat)) | |
| joblib.dump(model, CANDIDATE_MODEL) | |
| logger.info("Saved candidate -> %s", CANDIDATE_MODEL) | |
| # ============================================================================ | |
| # Step 4: preset benchmark (uses candidate model) | |
| # ============================================================================ | |
| def _make_priority_dispatch(model, fe, sim_ref): | |
| def dispatch(jobs, t, zone_id): | |
| sim = sim_ref[0] | |
| if not jobs or sim is None: | |
| return fifo_dispatch(jobs, t, zone_id) | |
| try: | |
| state = sim.get_state_snapshot() | |
| sf = fe.extract_scenario_features(state) | |
| feats = np.stack([ | |
| np.concatenate([sf, fe.extract_job_features(j, state)]) for j in jobs | |
| ]) | |
| scores = model.predict(feats) | |
| return [j for _, j in sorted(zip(scores, jobs), | |
| key=lambda x: x[0], reverse=True)] | |
| except Exception: | |
| return fifo_dispatch(jobs, t, zone_id) | |
| return dispatch | |
| def _run_one_preset(p, model) -> Dict[str, Any]: | |
| sim_kw = dict( | |
| base_arrival_rate=p.base_arrival_rate, breakdown_prob=p.breakdown_prob, | |
| batch_arrival_size=p.batch_arrival_size, lunch_penalty_factor=p.lunch_penalty_factor, | |
| job_type_frequencies=p.job_type_frequencies, | |
| due_date_tightness=p.due_date_tightness, | |
| processing_time_scale=p.processing_time_scale, | |
| ) | |
| fe = FeatureExtractor() | |
| base_fn = DISPATCH_FNS.get(p.favored_heuristic, fifo_dispatch) | |
| base_sim = WarehouseSimulator(seed=p.seed, heuristic_fn=base_fn, **sim_kw) | |
| base_metrics = base_sim.run(duration=600.0) | |
| sim_ref = [None] | |
| dispatch = _make_priority_dispatch(model, fe, sim_ref) | |
| dahs_sim = WarehouseSimulator(seed=p.seed, heuristic_fn=dispatch, | |
| feature_extractor=fe, **sim_kw) | |
| sim_ref[0] = dahs_sim | |
| dahs_metrics = dahs_sim.run(duration=600.0) | |
| return { | |
| "preset": p.name, | |
| "favored": p.favored_heuristic, | |
| "baseline_tardiness": float(base_metrics.total_tardiness), | |
| "dahs_tardiness": float(dahs_metrics.total_tardiness), | |
| "wins": float(dahs_metrics.total_tardiness) <= float(base_metrics.total_tardiness), | |
| } | |
| def verify_presets(model) -> Tuple[int, List[Dict[str, Any]]]: | |
| logger.info("VERIFY A: preset benchmark on candidate ...") | |
| rows: List[Dict[str, Any]] = [] | |
| for p in PRESETS: | |
| rows.append(_run_one_preset(p, model)) | |
| n_wins = sum(1 for r in rows if r["wins"]) | |
| logger.info("VERIFY A: %d/%d preset wins", n_wins, len(rows)) | |
| for r in rows: | |
| mark = "OK" if r["wins"] else "LOSS" | |
| logger.info(" [%s] %-22s base=%.0f dahs=%.0f", | |
| mark, r["preset"], r["baseline_tardiness"], r["dahs_tardiness"]) | |
| return n_wins, rows | |
| # ============================================================================ | |
| # Step 5: random-seed benchmark (uses candidate model) | |
| # ============================================================================ | |
| def _run_one_seed_all(seed: int, model) -> Dict[str, Any]: | |
| """Run all 6 baselines + DAHS-priority on one seed; return tardiness dict.""" | |
| fe = FeatureExtractor() | |
| out = {"seed": seed} | |
| # baselines | |
| for name, fn in DISPATCH_FNS.items(): | |
| sim = WarehouseSimulator(seed=seed, heuristic_fn=fn) | |
| m = sim.run(duration=600.0) | |
| out[name] = float(m.total_tardiness) | |
| # candidate priority | |
| sim_ref = [None] | |
| dispatch = _make_priority_dispatch(model, fe, sim_ref) | |
| sim = WarehouseSimulator(seed=seed, heuristic_fn=dispatch, feature_extractor=fe) | |
| sim_ref[0] = sim | |
| m = sim.run(duration=600.0) | |
| out["dahs_priority"] = float(m.total_tardiness) | |
| return out | |
| def verify_random(model, n_seeds: int = 20) -> Tuple[int, List[Dict[str, Any]]]: | |
| logger.info("VERIFY B: random-seed benchmark on %d seeds ...", n_seeds) | |
| rows: List[Dict[str, Any]] = [] | |
| for s in range(n_seeds): | |
| rows.append(_run_one_seed_all(s, model)) | |
| if (s + 1) % 5 == 0: | |
| logger.info(" random verify: %d/%d done", s + 1, n_seeds) | |
| n_wins = 0 | |
| for r in rows: | |
| baseline_tards = [r[h] for h in DISPATCH_FNS.keys()] | |
| if r["dahs_priority"] <= min(baseline_tards) + 1e-6: | |
| n_wins += 1 | |
| r["wins"] = True | |
| else: | |
| r["wins"] = False | |
| logger.info("VERIFY B: %d/%d random-seed wins", n_wins, n_seeds) | |
| return n_wins, rows | |
| # ============================================================================ | |
| # Main pipeline | |
| # ============================================================================ | |
| def main() -> None: | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--preset-floor", type=int, default=5, | |
| help="Minimum preset wins required to promote (current=5)") | |
| parser.add_argument("--random-floor", type=int, default=18, | |
| help="Minimum random-seed wins (out of 20) required to promote") | |
| parser.add_argument("--skip-data-gen", action="store_true", | |
| help="Reuse existing augmented dataset if present") | |
| args = parser.parse_args() | |
| print("\n" + "=" * 88) | |
| print(" FOOLPROOF RETRAIN PIPELINE") | |
| print("=" * 88) | |
| print(f" Preset floor: >= {args.preset_floor}/7 wins") | |
| print(f" Random floor: >= {args.random_floor}/20 wins") | |
| print(f" Live model: {LIVE_MODEL}") | |
| print(f" Backup will be: {BACKUP_MODEL}") | |
| print("=" * 88 + "\n") | |
| if not LIVE_MODEL.exists(): | |
| raise SystemExit(f"No live model at {LIVE_MODEL}; nothing to back up.") | |
| # Step 0: Backup | |
| logger.info("STEP 0: Backing up live model -> %s", BACKUP_MODEL) | |
| shutil.copy2(LIVE_MODEL, BACKUP_MODEL) | |
| # Step 1+2: Augment data | |
| if args.skip_data_gen and AUG_DATA.exists(): | |
| logger.info("STEP 1+2: Reusing existing %s", AUG_DATA) | |
| df_aug = pd.read_csv(AUG_DATA) | |
| else: | |
| logger.info("STEP 1+2: Generating augmented dataset") | |
| df_aug = generate_augmented_dataset() | |
| # Step 3: Train candidate | |
| logger.info("STEP 3: Training candidate GBR") | |
| train_candidate(df_aug) | |
| candidate = joblib.load(CANDIDATE_MODEL) | |
| # Step 4 + 5: Verify | |
| preset_wins, preset_rows = verify_presets(candidate) | |
| random_wins, random_rows = verify_random(candidate, n_seeds=20) | |
| # Step 6: Promote / rollback | |
| print("\n" + "=" * 88) | |
| print(" GATE DECISION") | |
| print("-" * 88) | |
| print(f" Preset wins: {preset_wins}/7 (floor: {args.preset_floor})") | |
| print(f" Random wins: {random_wins}/20 (floor: {args.random_floor})") | |
| promote = (preset_wins >= args.preset_floor) and (random_wins >= args.random_floor) | |
| gate_report = { | |
| "preset_wins": preset_wins, | |
| "random_wins": random_wins, | |
| "preset_floor": args.preset_floor, | |
| "random_floor": args.random_floor, | |
| "promoted": promote, | |
| "preset_rows": preset_rows, | |
| "random_rows": random_rows, | |
| } | |
| (RESULTS_DIR / "foolproof_retrain_report.json").write_text( | |
| json.dumps(gate_report, indent=2) | |
| ) | |
| if promote: | |
| os.replace(str(CANDIDATE_MODEL), str(LIVE_MODEL)) | |
| # Update preset_benchmark.json with new numbers | |
| out = [] | |
| for r in preset_rows: | |
| base = r["baseline_tardiness"] | |
| dahs = r["dahs_tardiness"] | |
| imp = (base - dahs) / base * 100.0 if base > 0 else 0.0 | |
| out.append({ | |
| "preset": r["preset"], | |
| "favored": r["favored"], | |
| "baseline_tardiness": round(base, 2), | |
| "dahs_tardiness": round(dahs, 2), | |
| "improvement_pct": round(imp, 2), | |
| "dahs_wins": r["wins"], | |
| }) | |
| (RESULTS_DIR / "preset_benchmark.json").write_text(json.dumps(out, indent=2)) | |
| print(" RESULT: PROMOTED. New model is live.") | |
| print(f" Old model preserved at: {BACKUP_MODEL}") | |
| else: | |
| try: | |
| CANDIDATE_MODEL.unlink() | |
| except FileNotFoundError: | |
| pass | |
| print(" RESULT: REJECTED. Live model unchanged.") | |
| print(f" Reason:") | |
| if preset_wins < args.preset_floor: | |
| print(f" - preset_wins={preset_wins} < floor={args.preset_floor}") | |
| if random_wins < args.random_floor: | |
| print(f" - random_wins={random_wins} < floor={args.random_floor}") | |
| print("=" * 88 + "\n") | |
| sys.exit(0 if promote else 1) | |
| if __name__ == "__main__": | |
| main() | |