File size: 6,086 Bytes
2850928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212eb34
2850928
 
 
212eb34
2850928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
"""
scripts/run_pipeline.py β€” DAHS_2 End-to-End Training Pipeline

Steps:
  1. Generate selector dataset (snapshot-fork, n_scenarios configurable)
  2. Generate priority dataset
  3. Train selector models (DT, RF, XGB)
  4. Train priority predictor (GBR)
  5. [Optional] Run benchmark evaluation (300 seeds)

Usage:
  python scripts/run_pipeline.py             # Full pipeline (1000 scenarios)
  python scripts/run_pipeline.py --quick     # Quick smoke test (50 scenarios, 20 seeds)
  python scripts/run_pipeline.py --eval-only # Run evaluation only (models must exist)
"""
from __future__ import annotations

import argparse
import logging
import sys
import time
from pathlib import Path

# Force UTF-8 stdout/stderr on Windows so unicode chars (βœ“, Γ—, β†’) don't
# crash the pipeline after hours of data generation.
for _stream in ("stdout", "stderr"):
    try:
        getattr(sys, _stream).reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        pass

# Make sure src is importable from project root
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))

(ROOT / "logs").mkdir(exist_ok=True)
(ROOT / "data" / "raw").mkdir(parents=True, exist_ok=True)
(ROOT / "models").mkdir(exist_ok=True)
(ROOT / "results" / "plots").mkdir(parents=True, exist_ok=True)

_stream_handler = logging.StreamHandler()
_file_handler = logging.FileHandler(ROOT / "logs" / "pipeline.log", mode="a", encoding="utf-8")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[_stream_handler, _file_handler],
)
logger = logging.getLogger(__name__)


def step(n: int, label: str) -> None:
    print(f"\n{'=' * 60}")
    print(f"  STEP {n}: {label}")
    print(f"{'=' * 60}\n")


def main() -> None:
    parser = argparse.ArgumentParser(description="DAHS_2 Training Pipeline")
    parser.add_argument("--quick", action="store_true", help="Quick smoke test (50 scenarios, 20 eval seeds)")
    parser.add_argument("--eval-only", action="store_true", help="Skip training, run evaluation only")
    parser.add_argument("--no-eval", action="store_true", help="Skip benchmark evaluation")
    parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers")
    parser.add_argument("--scenarios", type=int, default=None, help="Override number of scenarios")
    parser.add_argument("--eval-seeds", type=int, default=None, help="Override number of evaluation seeds")
    args = parser.parse_args()

    n_scenarios = args.scenarios or (50 if args.quick else 1000)
    n_eval_seeds = args.eval_seeds or (20 if args.quick else 1000)
    n_workers = args.workers

    t_start = time.time()

    print("\n" + "=" * 60)
    print("  DAHS 2.0 β€” Full Training & Evaluation Pipeline")
    print(f"  Scenarios: {n_scenarios} | Workers: {n_workers}")
    print("=" * 60)

    if not args.eval_only:
        # ── Step 1: Selector dataset ─────────────────────────────────
        step(1, "Snapshot-Fork Selector Dataset")
        from src.data_generator import generate_selector_dataset
        t = time.time()
        df = generate_selector_dataset(n_scenarios=n_scenarios, n_workers=n_workers)
        logger.info("Selector dataset: %d rows in %.1fs", len(df), time.time() - t)
        print(f"  βœ“ Selector dataset: {len(df):,} rows")

        # ── Step 2: Priority dataset ─────────────────────────────────
        step(2, "Priority Predictor Dataset")
        from src.data_generator import generate_priority_dataset
        t = time.time()
        priority_df = generate_priority_dataset(
            n_scenarios=min(n_scenarios * 5, 5_000),
            n_points_per=10,
            n_workers=n_workers,
        )
        logger.info("Priority dataset: %d rows in %.1fs", len(priority_df), time.time() - t)
        print(f"  βœ“ Priority dataset: {len(priority_df):,} rows")

        # ── Step 3: Train selectors ──────────────────────────────────
        step(3, "Train Selector Models (DT + RF + XGB)")
        from src.train_selector import train_selector_models
        t = time.time()
        selector_models = train_selector_models()
        logger.info("Selector training done in %.1fs", time.time() - t)
        print(f"  βœ“ Trained: {list(selector_models.keys())}")

        # ── Step 4: Train priority predictor ────────────────────────
        step(4, "Train Priority Predictor (GBR)")
        from src.train_priority import train_priority_model
        t = time.time()
        gbr = train_priority_model()
        logger.info("Priority training done in %.1fs", time.time() - t)
        print("  βœ“ Priority GBR trained")

    # ── Step 5: Benchmark evaluation ─────────────────────────────────
    if not args.no_eval:
        step(5, "Benchmark Evaluation")
        from src.evaluator import run_full_evaluation
        t = time.time()
        eval_seeds = list(range(99000, 99000 + n_eval_seeds))
        results = run_full_evaluation(seeds=eval_seeds, n_workers=n_workers)
        logger.info("Evaluation done: %d seeds in %.1fs", n_eval_seeds, time.time() - t)
        print(f"  βœ“ Evaluation complete ({n_eval_seeds} seeds)")

        # Print summary
        bench_df = results["benchmark"]
        if not bench_df.empty:
            print("\n  Performance Summary (mean total tardiness):")
            for method in sorted(bench_df["method"].unique()):
                mean_t = bench_df[bench_df["method"] == method]["total_tardiness"].mean()
                print(f"    {method:<20}: {mean_t:>8.1f}")

    elapsed = time.time() - t_start
    print(f"\n  Pipeline complete in {elapsed / 60:.1f} minutes.")
    print(f"  Artifacts saved to: {ROOT / 'models'} and {ROOT / 'results'}\n")


if __name__ == "__main__":
    main()