dash-jsp-trainer / scripts /train_bandit.py
Vittal-M's picture
Trainer Space: download -> train -> push -> sleep
52c82e4
"""Local-machine bandit training. Mirrors hf/train_on_hf.py but with smaller defaults."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "hf"))
from train_on_hf import _gather_instances, _train_one_pass # noqa: E402
from dash_jsp.bandit.linucb import LinUCBDispatcher
from dash_jsp.bandit.thompson import ThompsonDispatcher
def main(argv: list[str] | None = None) -> None:
    """Train LinUCB and Thompson dispatchers locally and save them to disk.

    Each epoch runs one ``_train_one_pass`` over the gathered instances for
    each dispatcher. After all epochs, both dispatchers are saved as ``.npz``
    files and the per-epoch stats are written to ``training_log.json`` in
    ``--output-dir``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default
            ``None`` keeps the normal command-line behavior; passing a list
            lets other scripts or tests drive the entry point directly.

    Exits via ``SystemExit`` when ``--epochs`` is less than 1 or when no
    instances are found for the requested benchmarks.
    """
    parser = argparse.ArgumentParser(
        description="Train LinUCB and Thompson bandit dispatchers on local benchmark instances."
    )
    parser.add_argument("--data-dir", default="data",
                        help="directory to load benchmark instances from (default: %(default)s)")
    parser.add_argument("--output-dir", default="models",
                        help="directory for saved models and training_log.json (default: %(default)s)")
    parser.add_argument("--benchmarks", nargs="+", default=["lawrence"],
                        choices=["taillard", "lawrence", "dmu"],
                        help="benchmark families to train on (default: %(default)s)")
    parser.add_argument("--epochs", type=int, default=2,
                        help="number of training passes over all instances (default: %(default)s)")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="alpha parameter for LinUCBDispatcher (default: %(default)s)")
    parser.add_argument("--seed", type=int, default=42,
                        help="rng_seed shared by both dispatchers (default: %(default)s)")
    args = parser.parse_args(argv)
    if args.epochs < 1:
        # With zero/negative epochs the loop below never runs, yet the freshly
        # constructed (untrained) dispatchers would still be saved.
        parser.error("--epochs must be >= 1")

    instances = _gather_instances(args.benchmarks, Path(args.data_dir))
    if not instances:
        # Fail loudly rather than writing untrained models and an empty log.
        sys.exit(
            f"error: no instances found for benchmarks {args.benchmarks} "
            f"under {args.data_dir}"
        )

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    linucb = LinUCBDispatcher(alpha=args.alpha, rng_seed=args.seed)
    thompson = ThompsonDispatcher(rng_seed=args.seed)

    log = {"epochs": [], "benchmarks": args.benchmarks, "n_instances": len(instances)}
    for ep in range(args.epochs):
        lin = _train_one_pass(linucb, instances)
        ts = _train_one_pass(thompson, instances)
        log["epochs"].append({"epoch": ep, "linucb": lin, "thompson": ts})
        print(f"epoch {ep}: linucb pulls={lin['n_arm_pulls']} thompson pulls={ts['n_arm_pulls']}")

    linucb.save(str(out_dir / "bandit_linucb.npz"))
    thompson.save(str(out_dir / "bandit_thompson.npz"))
    (out_dir / "training_log.json").write_text(json.dumps(log, indent=2), encoding="utf-8")
    print(f"Saved to {out_dir}/")
if __name__ == "__main__":
    # Only run training when executed as a script, not when imported.
    main()