"""Local-machine bandit training. Mirrors hf/train_on_hf.py but with smaller defaults.""" from __future__ import annotations import argparse import json import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "hf")) from train_on_hf import _gather_instances, _train_one_pass # noqa: E402 from dash_jsp.bandit.linucb import LinUCBDispatcher from dash_jsp.bandit.thompson import ThompsonDispatcher def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--data-dir", default="data") parser.add_argument("--output-dir", default="models") parser.add_argument("--benchmarks", nargs="+", default=["lawrence"], choices=["taillard", "lawrence", "dmu"]) parser.add_argument("--epochs", type=int, default=2) parser.add_argument("--alpha", type=float, default=1.0) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() instances = _gather_instances(args.benchmarks, Path(args.data_dir)) out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) linucb = LinUCBDispatcher(alpha=args.alpha, rng_seed=args.seed) thompson = ThompsonDispatcher(rng_seed=args.seed) log = {"epochs": [], "benchmarks": args.benchmarks, "n_instances": len(instances)} for ep in range(args.epochs): lin = _train_one_pass(linucb, instances) ts = _train_one_pass(thompson, instances) log["epochs"].append({"epoch": ep, "linucb": lin, "thompson": ts}) print(f"epoch {ep}: linucb pulls={lin['n_arm_pulls']} thompson pulls={ts['n_arm_pulls']}") linucb.save(str(out_dir / "bandit_linucb.npz")) thompson.save(str(out_dir / "bandit_thompson.npz")) (out_dir / "training_log.json").write_text(json.dumps(log, indent=2)) print(f"Saved to {out_dir}/") if __name__ == "__main__": main()