# jaxgmg2_3phase_optim_state

Checkpoints from 32 RL agent training runs on the JaxGMG maze environment with discount_rate=0.98, with the full optimizer state saved at each checkpoint. These are the primary runs for the 3-phase training regime analysis. The setup is similar to `jaxgmg2_3phase_unique`, but with optimizer state logging enabled.

The 32 runs are split evenly across two alpha values:

- 16 runs with alpha=0.6 (run_id 15-30)
- 16 runs with alpha=1.0 (run_id 15-30)
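To check the counts above, here is a minimal Python sketch that enumerates the sweep grid. The names `ALPHAS`, `RUN_IDS` and `sweep_grid` are hypothetical and do not come from the training code. The key detail is that `run_id 15-30` is inclusive, which is what gives 16 runs per alpha.

```python
from itertools import product

# Hypothetical constants mirroring the sweep described above.
ALPHAS = (0.6, 1.0)
RUN_IDS = range(15, 31)  # range() excludes its end, so this is run_id 15..30 inclusive


def sweep_grid():
    """Return every (alpha, run_id) pair in the sweep, alpha-major."""
    return list(product(ALPHAS, RUN_IDS))


grid = sweep_grid()
print(len(grid))  # 32
print(grid[0], grid[-1])  # (0.6, 15) (1.0, 30)
```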
|
|
**WandB:** https://wandb.ai/devinterp/jaxgmg2_3phase_optim_state

## Sweep

For each alpha value, run_id is swept over 15-30. The seed is derived from run_id via:

`seed = int(discount_rate*100)*10000 + int(alpha*10)*100 + run_id`

For example, alpha=0.6, run_id=15 gives seed=980615, and alpha=1.0, run_id=15 gives seed=981015.
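The seed formula can be checked directly. The sketch below uses hypothetical function names. It implements both the arithmetic form above and the zero-padded string form listed as `seed_formula` under Shared Hyperparams, and confirms that the two agree for every run in the sweep. They agree because each field fits in two digits: run_id is at most 30 and int(alpha*10) is at most 10.

```python
def seed_arithmetic(discount_rate: float, alpha: float, run_id: int) -> int:
    """seed = int(discount_rate*100)*10000 + int(alpha*10)*100 + run_id"""
    return int(discount_rate * 100) * 10000 + int(alpha * 10) * 100 + run_id


def seed_string(discount_rate: float, alpha: float, run_id: int) -> int:
    """Zero-padded concatenation, as in seed_formula:
    {int(discount_rate*100):02d}{int(alpha*10):02d}{run_id:02d}"""
    return int(f"{int(discount_rate * 100):02d}{int(alpha * 10):02d}{run_id:02d}")


# The worked examples from this README.
print(seed_arithmetic(0.98, 0.6, 15))  # 980615
print(seed_arithmetic(0.98, 1.0, 15))  # 981015

# Both forms agree across the full sweep, and all 32 seeds are distinct.
seeds = [seed_arithmetic(0.98, a, r) for a in (0.6, 1.0) for r in range(15, 31)]
assert seeds == [seed_string(0.98, a, r) for a in (0.6, 1.0) for r in range(15, 31)]
assert len(set(seeds)) == 32
```

One caveat: `int()` truncates, so floating-point error can shift a digit for other values. For example, `int(0.29 * 100)` is 28. For the values used here, `0.98 * 100`, `0.6 * 10` and `1.0 * 10` evaluate to exactly 98.0, 6.0 and 10.0.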
## Shared Hyperparams

```
rl_action=train
alpha=0.6 or 1.0
discount_rate=0.98
lr=5e-05
num_total_env_steps=10000000000
num_rollout_steps=64
num_levels=9600
cheese_loc=any
env_layout=open
env_size=13
mask_type=first_episode
use_prev_action=False
grad_acc_per_chunk=5
log_optimizer_state=True
eval_schedule=0:1,250:2,500:5,2000:10
seed_formula={int(discount_rate*100):02d}{int(alpha*10):02d}{run_id:02d}
f_str_ckpt=al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}
ckpt_dir=jaxgmg2_3phase_optim_state
wandb_project=jaxgmg2_3phase_optim_state
use_wandb=True
use_hf=True
```
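This README does not define the `eval_schedule` string. As an illustration only, the sketch below parses it under one assumed reading. Each comma-separated `start:interval` pair is taken to mean that from step `start` onward, an evaluation runs every `interval` steps until the next pair's start. That reading is an assumption, as is the unit of a "step" (environment steps versus PPO updates). The helper names are also hypothetical, so check the jaxgmg training code before relying on this.

```python
# ASSUMPTION: "start:interval" pairs, each active from `start` until the next pair.
# This interpretation is not documented in this README.
EVAL_SCHEDULE = "0:1,250:2,500:5,2000:10"


def parse_eval_schedule(spec: str) -> list[tuple[int, int]]:
    """Parse "a:b,c:d,..." into [(a, b), (c, d), ...], sorted by start step."""
    pairs = []
    for part in spec.split(","):
        start, interval = part.split(":")
        pairs.append((int(start), int(interval)))
    return sorted(pairs)


def should_eval(step: int, schedule: list[tuple[int, int]]) -> bool:
    """True if `step` is a multiple of the interval of the last entry starting at or before it."""
    interval = None
    for start, iv in schedule:
        if start <= step:
            interval = iv
    return interval is not None and step % interval == 0


schedule = parse_eval_schedule(EVAL_SCHEDULE)
print(schedule)  # [(0, 1), (250, 2), (500, 5), (2000, 10)]
# Under this reading, the first 2000 steps contain 250 + 125 + 300 = 675 evaluations.
print(sum(should_eval(s, schedule) for s in range(2000)))  # 675
```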
## Naming Schema

Checkpoints are named `al_{alpha}_g_0.98_id_{run_id}_seed_{seed}`, for example `al_0.6_g_0.98_id_15_seed_980615`.
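The name encodes every swept field, so it can be built from a run's parameters and parsed back into them. Below is a minimal sketch. It assumes names follow the `f_str_ckpt` template exactly and use Python's default float formatting, so alpha=1.0 renders as `1.0`. `ckpt_name`, `parse_ckpt_name` and `CKPT_RE` are hypothetical helpers and are not part of the training code.

```python
import re

# Hypothetical regex for names of the form al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}.
CKPT_RE = re.compile(
    r"al_(?P<alpha>\d+(?:\.\d+)?)"
    r"_g_(?P<discount_rate>\d+(?:\.\d+)?)"
    r"_id_(?P<run_id>\d+)"
    r"_seed_(?P<seed>\d+)"
)


def ckpt_name(alpha: float, discount_rate: float, run_id: int, seed: int) -> str:
    """Format a name with the same layout as f_str_ckpt (default str() of floats)."""
    return f"al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}"


def parse_ckpt_name(name: str) -> dict:
    """Recover alpha, discount_rate, run_id and seed from a checkpoint name."""
    m = CKPT_RE.fullmatch(name)
    if m is None:
        raise ValueError(f"not a checkpoint name: {name!r}")
    return {
        "alpha": float(m["alpha"]),
        "discount_rate": float(m["discount_rate"]),
        "run_id": int(m["run_id"]),
        "seed": int(m["seed"]),
    }


name = ckpt_name(0.6, 0.98, 15, 980615)
print(name)  # al_0.6_g_0.98_id_15_seed_980615
print(parse_ckpt_name("al_1.0_g_0.98_id_30_seed_981030"))
# {'alpha': 1.0, 'discount_rate': 0.98, 'run_id': 30, 'seed': 981030}
```

A round trip through `parse_ckpt_name(ckpt_name(...))` returns the original values. This gives a cheap way to catch a mistyped name when selecting checkpoints.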
## Reproduced with

See [`train.yaml`](./train.yaml) in this repository. From the [timaeus monorepo](https://github.com/timaeus-research/timaeus), run:

```bash
make run projects/rl/experiments/al_0.6_g_0.98/jobs/train_optim_state.yaml
```