Spaces:
Runtime error
Runtime error
| """Local-machine bandit training. Mirrors hf/train_on_hf.py but with smaller defaults.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "hf")) | |
| from train_on_hf import _gather_instances, _train_one_pass # noqa: E402 | |
| from dash_jsp.bandit.linucb import LinUCBDispatcher | |
| from dash_jsp.bandit.thompson import ThompsonDispatcher | |
def main() -> None:
    """Train the LinUCB and Thompson dispatchers and save them to disk.

    Reads the run configuration from the command line, gathers the
    instances of the requested benchmarks, performs ``--epochs`` training
    passes over those instances for each dispatcher, and finally writes
    both dispatchers plus a JSON training log into ``--output-dir``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-dir", default="data")
    cli.add_argument("--output-dir", default="models")
    cli.add_argument(
        "--benchmarks",
        nargs="+",
        default=["lawrence"],
        choices=["taillard", "lawrence", "dmu"],
    )
    cli.add_argument("--epochs", type=int, default=2)
    cli.add_argument("--alpha", type=float, default=1.0)
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    instances = _gather_instances(opts.benchmarks, Path(opts.data_dir))

    target_dir = Path(opts.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Both dispatchers receive the same seed.
    linucb_bandit = LinUCBDispatcher(alpha=opts.alpha, rng_seed=opts.seed)
    thompson_bandit = ThompsonDispatcher(rng_seed=opts.seed)

    history = []
    for epoch in range(opts.epochs):
        # Within an epoch LinUCB is trained first, then Thompson, on the
        # same instance list.
        linucb_stats = _train_one_pass(linucb_bandit, instances)
        thompson_stats = _train_one_pass(thompson_bandit, instances)
        history.append(
            {"epoch": epoch, "linucb": linucb_stats, "thompson": thompson_stats}
        )
        print(
            f"epoch {epoch}: linucb pulls={linucb_stats['n_arm_pulls']} "
            f"thompson pulls={thompson_stats['n_arm_pulls']}"
        )

    # Persist once, after the last epoch: both models, then the log.
    linucb_path = target_dir / "bandit_linucb.npz"
    thompson_path = target_dir / "bandit_thompson.npz"
    linucb_bandit.save(str(linucb_path))
    thompson_bandit.save(str(thompson_path))

    summary = {
        "epochs": history,
        "benchmarks": opts.benchmarks,
        "n_instances": len(instances),
    }
    log_path = target_dir / "training_log.json"
    log_path.write_text(json.dumps(summary, indent=2))
    print(f"Saved to {target_dir}/")
# Run the training only when this file is executed as a script, so that
# importing the module does not start a training run.
if __name__ == "__main__":
    main()