# Disruption-System / scripts/run_pipeline.py
# Author: Vittal-M — commit 906e104 ("Upload 66 files", verified)
#!/usr/bin/env python3
"""scripts/run_pipeline.py — DAHS_2 End-to-End Training Pipeline.
Steps:
1. Generate selector dataset (snapshot-fork)
2. Generate priority dataset
3. Train selector models (DT, RF, XGB)
4. Train priority predictor (GBR)
5. Run benchmark evaluation
Each step is followed by an *incremental* Hub snapshot so partial progress
survives even if the Space runtime is killed mid-pipeline.
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import platform
import socket
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
# Best-effort: switch the console streams to UTF-8 so non-ASCII status text
# is written as UTF-8 (unencodable characters are replaced instead of raising).
# A stream whose reconfigure() is missing or fails is simply left untouched.
for _stream_name in ("stdout", "stderr"):
    try:
        _reconfigure = getattr(sys, _stream_name).reconfigure
        _reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        continue
# Repository root: this script lives in <root>/scripts/. Putting the root on
# sys.path lets the ``src`` package imports used by the pipeline resolve when
# the file is run directly as a script.
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))

# Output directories the pipeline writes into. ``logs`` has to exist before
# the file handler below opens logs/pipeline.log.
ROOT.joinpath("logs").mkdir(exist_ok=True)
ROOT.joinpath("data", "raw").mkdir(parents=True, exist_ok=True)
ROOT.joinpath("models").mkdir(exist_ok=True)
ROOT.joinpath("results", "plots").mkdir(parents=True, exist_ok=True)

# Log INFO and above both to the console and (appending) to logs/pipeline.log.
_stream_handler = logging.StreamHandler()
_file_handler = logging.FileHandler(
    ROOT.joinpath("logs", "pipeline.log"),
    mode="a",
    encoding="utf-8",
)
logging.basicConfig(
    handlers=[_stream_handler, _file_handler],
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def step(n: int, label: str) -> None:
    """Print a banner announcing pipeline step *n*.

    The title line ``" STEP <n>: <label>"`` is framed by two 60-character
    ``=`` rules, with a blank line before the banner and another after it.
    """
    rule = "=" * 60
    print(f"\n{rule}\n STEP {n}: {label}\n{rule}\n")
def _git_sha() -> str:
    """Return the commit SHA of ``HEAD`` for the repository rooted at ``ROOT``.

    Returns ``"unknown"`` when the SHA cannot be determined: ``git`` is not
    installed or cannot be executed, ``ROOT`` is not inside a git work tree
    (``git rev-parse`` exits non-zero), or the command exceeds the timeout.
    git's own stderr is discarded.
    """
    try:
        out = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            cwd=ROOT,
            stderr=subprocess.DEVNULL,
            # Provenance is best-effort; never let it block the pipeline start.
            timeout=10,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: git missing / not executable / bad cwd.
        # SubprocessError: CalledProcessError (non-zero exit) or TimeoutExpired.
        return "unknown"
    return out.decode().strip()
def _pip_freeze_to(path: Path) -> None:
    """Save ``pip freeze`` output for the running interpreter to *path*.

    Best-effort: any failure (pip unavailable, non-zero exit, unwritable
    path, ...) is logged as a warning and swallowed so it never aborts
    the caller.
    """
    freeze_cmd = [sys.executable, "-m", "pip", "freeze"]
    try:
        frozen = subprocess.check_output(freeze_cmd).decode()
        path.write_text(frozen, encoding="utf-8")
    except Exception as exc:  # noqa: BLE001
        logger.warning("pip freeze failed: %s", exc)
def _write_run_manifest(args: argparse.Namespace, n_scenarios: int, n_eval_seeds: int) -> None:
    """Record run provenance under ``ROOT / "results"``.

    Writes ``run_manifest.json`` containing the UTC timestamp at which the
    manifest is written, the git SHA, host/platform/Python details, CPU
    count, the parsed CLI arguments, the resolved scenario and eval-seed
    counts, selected environment variables and library versions. It then
    writes ``pip_freeze.txt`` next to it.

    ``HF_TOKEN`` is recorded only as a set/unset flag, never its value.

    Args:
        args: Parsed command-line arguments (stored via ``vars(args)``).
        n_scenarios: Number of scenarios the run will generate.
        n_eval_seeds: Number of evaluation seeds the run will use.
    """
    import importlib

    manifest = {
        "started_at": datetime.now(timezone.utc).isoformat(),
        "git_sha": _git_sha(),
        "host": socket.gethostname(),
        "platform": platform.platform(),
        "python": sys.version,
        "cpu_count": os.cpu_count(),
        "args": vars(args),
        "n_scenarios": n_scenarios,
        "n_eval_seeds": n_eval_seeds,
        "env": {
            "REPO_ID": os.environ.get("REPO_ID"),
            "SPACE_ID": os.environ.get("SPACE_ID"),
            "HF_TOKEN_set": bool(os.environ.get("HF_TOKEN")),
        },
    }

    # Resolve each library independently so that one missing or broken
    # package no longer drops the version info for all the others. A
    # package that fails to import is recorded as None. Importing a
    # third-party package can fail with errors other than ImportError,
    # hence the broad catch.
    versions: dict[str, str | None] = {}
    for module_name in ("sklearn", "xgboost", "scipy", "numpy", "pandas"):
        try:
            module = importlib.import_module(module_name)
        except Exception:  # noqa: BLE001 - provenance is best-effort
            versions[module_name] = None
            continue
        versions[module_name] = getattr(module, "__version__", None)
    manifest["versions"] = versions

    (ROOT / "results" / "run_manifest.json").write_text(
        json.dumps(manifest, indent=2), encoding="utf-8"
    )
    _pip_freeze_to(ROOT / "results" / "pip_freeze.txt")
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the pipeline."""
    parser = argparse.ArgumentParser(description="DAHS_2 Training Pipeline")
    parser.add_argument("--quick", action="store_true", help="Quick smoke test")
    parser.add_argument("--eval-only", action="store_true", help="Skip training, run eval only")
    parser.add_argument("--no-eval", action="store_true", help="Skip benchmark evaluation")
    parser.add_argument("--workers", type=int, default=4, help="Parallel workers")
    parser.add_argument("--scenarios", type=int, default=None, help="Override scenario count")
    parser.add_argument("--eval-seeds", type=int, default=None, help="Override eval seed count")
    # Per-step pushes are on by default; both flags share one dest so
    # ``--no-snapshot-every-step`` can actually turn them off.
    parser.add_argument("--snapshot-every-step", dest="snapshot_every_step",
                        action="store_true", default=True,
                        help="Push to HF Hub after each pipeline step")
    parser.add_argument("--no-snapshot-every-step", dest="snapshot_every_step",
                        action="store_false",
                        help="Skip per-step Hub pushes (periodic and final snapshots still run)")
    return parser


def _count_or_default(parser: argparse.ArgumentParser, option: str,
                      value: int | None, default: int) -> int:
    """Return *value* if given, else *default*; exit via *parser* if value < 1."""
    # ``is None`` rather than ``or`` so an explicit 0 is rejected instead of
    # being silently replaced by the default.
    if value is None:
        return default
    if value < 1:
        parser.error(f"{option} must be >= 1 (got {value})")
    return value


def _run_training(n_scenarios: int, n_workers: int, snapshot_step) -> None:
    """Steps 1-4: generate both datasets, then train selector and priority models.

    *snapshot_step* is called as ``snapshot_step(target, msg)`` after each step.
    """
    # Step 1
    step(1, "Snapshot-Fork Selector Dataset")
    from src.data_generator import generate_selector_dataset

    t = time.time()
    df = generate_selector_dataset(n_scenarios=n_scenarios, n_workers=n_workers)
    logger.info("Selector dataset: %d rows in %.1fs", len(df), time.time() - t)
    print(f" ✓ Selector dataset: {len(df):,} rows")
    snapshot_step("data", "selector_dataset")

    # Step 2 (5x the selector scenario count, capped at 5,000)
    step(2, "Priority Predictor Dataset")
    from src.data_generator import generate_priority_dataset

    t = time.time()
    priority_df = generate_priority_dataset(
        n_scenarios=min(n_scenarios * 5, 5_000),
        n_points_per=10,
        n_workers=n_workers,
    )
    logger.info("Priority dataset: %d rows in %.1fs", len(priority_df), time.time() - t)
    print(f" ✓ Priority dataset: {len(priority_df):,} rows")
    snapshot_step("data", "priority_dataset")

    # Step 3
    step(3, "Train Selector Models (DT + RF + XGB)")
    from src.train_selector import train_selector_models

    t = time.time()
    selector_models = train_selector_models()
    logger.info("Selector training done in %.1fs", time.time() - t)
    print(f" ✓ Trained: {list(selector_models.keys())}")
    snapshot_step("models", "selector_models")
    snapshot_step("results", "selector_metrics")

    # Step 4
    step(4, "Train Priority Predictor (GBR)")
    from src.train_priority import train_priority_model

    t = time.time()
    train_priority_model()
    logger.info("Priority training done in %.1fs", time.time() - t)
    print(" ✓ Priority GBR trained")
    snapshot_step("models", "priority_model")
    snapshot_step("results", "priority_metrics")


def _print_benchmark_summary(bench_df) -> None:
    """Print mean ``total_tardiness`` per ``method``, methods in sorted order."""
    # assumes bench_df is a pandas DataFrame (it exposes .empty / column access)
    if bench_df.empty:
        return
    print("\n Performance Summary (mean total tardiness):")
    # One grouped pass instead of re-masking the frame per method; groupby
    # sorts its keys by default, matching the previous sorted(unique()) order.
    means = bench_df.groupby("method")["total_tardiness"].mean()
    for method, mean_t in means.items():
        print(f" {method:<22}: {mean_t:>8.1f}")


def _run_evaluation(n_eval_seeds: int, n_workers: int, snapshot_step) -> None:
    """Step 5: run the benchmark over a fixed seed range and print a summary."""
    step(5, "Benchmark Evaluation")
    from src.evaluator import run_full_evaluation

    t = time.time()
    # Seeds start at a fixed base (99000) so the evaluation set is identical
    # across runs; presumably chosen to stay clear of training seeds (TODO confirm).
    eval_seeds = list(range(99000, 99000 + n_eval_seeds))
    results = run_full_evaluation(seeds=eval_seeds, n_workers=n_workers)
    logger.info("Evaluation done: %d seeds in %.1fs", n_eval_seeds, time.time() - t)
    print(f" ✓ Evaluation complete ({n_eval_seeds} seeds)")
    snapshot_step("results", "evaluation")
    _print_benchmark_summary(results["benchmark"])


def main() -> None:
    """Run the DAHS_2 training + evaluation pipeline from the command line.

    Steps 1-4 (dataset generation and model training) are skipped with
    ``--eval-only``. Step 5 (benchmark evaluation) is skipped with
    ``--no-eval``.

    Hub snapshots are taken at run start and after every step; the per-step
    pushes can be disabled with ``--no-snapshot-every-step``. The 5-minute
    periodic snapshots and the final consolidated snapshot always run.

    Raises:
        SystemExit: from argparse on invalid or contradictory arguments.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.eval_only and args.no_eval:
        parser.error("--eval-only and --no-eval together leave nothing to run")
    n_scenarios = _count_or_default(parser, "--scenarios", args.scenarios,
                                    50 if args.quick else 1000)
    n_eval_seeds = _count_or_default(parser, "--eval-seeds", args.eval_seeds,
                                     20 if args.quick else 1000)
    if args.workers < 1:
        parser.error(f"--workers must be >= 1 (got {args.workers})")
    n_workers = args.workers
    t_start = time.time()

    # Bulletproof Hub persistence — no-op if env vars unset (local runs).
    from src.hf_persistence import from_env

    persistor = from_env(require=False)
    persistor.install_signal_handlers()
    persistor.install_atexit()
    persistor.start_periodic(interval_seconds=300)  # every 5 min

    def snapshot_step(target: str, msg: str) -> None:
        """Push *target* after a pipeline step unless per-step pushes are disabled."""
        if args.snapshot_every_step:
            persistor.snapshot(target, msg=msg)

    try:
        _write_run_manifest(args, n_scenarios, n_eval_seeds)
        persistor.snapshot("results", msg="run_start manifest")

        print("\n" + "=" * 60)
        print(" DAHS 2.0 — Full Training & Evaluation Pipeline")
        print(f" Scenarios: {n_scenarios} | Eval seeds: {n_eval_seeds} | Workers: {n_workers}")
        print("=" * 60)

        if not args.eval_only:
            _run_training(n_scenarios, n_workers, snapshot_step)
        if not args.no_eval:
            _run_evaluation(n_eval_seeds, n_workers, snapshot_step)
    finally:
        # Stop the background snapshot timer even if a step raised, so a
        # failing run cannot leave it running (whether its thread would block
        # interpreter shutdown depends on hf_persistence — not visible here).
        persistor.stop_periodic()

    elapsed = time.time() - t_start
    print(f"\n Pipeline complete in {elapsed / 60:.1f} minutes.")
    print(f" Artifacts: {ROOT / 'models'}, {ROOT / 'results'}, {ROOT / 'data'}")
    # Final consolidated snapshot (always pushed, regardless of per-step setting).
    persistor.snapshot(msg=f"pipeline_complete_{int(elapsed)}s")


if __name__ == "__main__":
    main()