# jaxgmg2_shared_init
20 RL agent checkpoints for studying the effect of shared initialization. Two base models
(`al_1.0_g_0.98_id_19_seed_981019` and `al_1.0_g_0.98_id_27_seed_981027` from
[jaxgmg2_3phase_optim_state](https://huggingface.co/timaeus/jaxgmg2_3phase_optim_state))
each serve as a shared starting point. From each base model, training is continued
independently from checkpoint 0, with a fresh optimizer state and alpha=1.0, under 10
different random seeds.

**WandB:** https://wandb.ai/devinterp/jaxgmg2_shared_init
## Sweep

2 base models × 10 seeds (30-39) = 20 runs in total.

Base models resumed:

- `jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_19_seed_981019`
- `jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_27_seed_981027`
## Shared Hyperparams

```
rl_action=train
alpha=1.0
discount_rate=0.98
lr=5e-05
num_total_env_steps=1351680000
num_rollout_steps=64
num_levels=9600
cheese_loc=any
env_layout=open
env_size=13
resume_id=0
resume_optim=False
grad_acc_per_chunk=4
log_optimizer_state=True
eval_schedule=0:1,250:2,500:5,2000:10
f_str_ckpt=al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}
ckpt_dir=jaxgmg2_shared_init
wandb_project=jaxgmg2_shared_init
use_wandb=True
use_hf=True
```
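As a rough sanity check on the training budget, the script below derives the number of training iterations and evaluation points from the values above. It is a sketch that rests on two assumptions not stated in this card: that each iteration collects `num_rollout_steps` steps on each of the `num_levels` levels, and that each `start:interval` pair in `eval_schedule` means "from iteration `start` onward, evaluate every `interval` iterations" until the next pair takes over.

```bash
#!/usr/bin/env bash
# Sketch only: both interpretations below are assumptions, not taken from the training code.
set -euo pipefail

num_total_env_steps=1351680000
num_levels=9600
num_rollout_steps=64
eval_schedule="0:1,250:2,500:5,2000:10"

# Assumption: one rollout of num_rollout_steps per level per iteration.
steps_per_iter=$(( num_levels * num_rollout_steps ))
num_iters=$(( num_total_env_steps / steps_per_iter ))
echo "env steps per iteration: ${steps_per_iter}"
echo "iterations: ${num_iters}"

# Assumption: "start:interval" = from iteration start onward, evaluate every
# interval iterations, until the next pair's start (or the end of training).
IFS=',' read -r -a pairs <<< "${eval_schedule}"
num_evals=0
for (( i = 0; i < ${#pairs[@]}; i++ )); do
  start=${pairs[i]%%:*}
  interval=${pairs[i]##*:}
  if (( i + 1 < ${#pairs[@]} )); then
    end=${pairs[i+1]%%:*}
  else
    end=${num_iters}
  fi
  # Count iterations start, start+interval, ... that are strictly below end,
  # i.e. ceil((end - start) / interval).
  if (( end > start )); then
    num_evals=$(( num_evals + (end - start + interval - 1) / interval ))
  fi
done
echo "evaluations: ${num_evals}"
```

Under those assumptions it reports 614400 environment steps per iteration, exactly 2200 iterations, and 695 evaluations (not counting any extra evaluation at the end of training). The even division, and the last schedule boundary at 2000 falling just below 2200, are at least consistent with this reading.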
## Naming Schema

Checkpoints are named `al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}`, matching the `f_str_ckpt` template in the hyperparameters.
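A minimal sketch that expands the naming schema into all 20 checkpoint names. It assumes `{run_id}` is filled with the base model's id (19 or 27) and `{seed}` with the sweep seeds 30-39; neither mapping is spelled out explicitly here, so treat the list as illustrative.

```bash
#!/usr/bin/env bash
# Sketch: enumerate checkpoint names from the naming schema.
# Assumption: {run_id} is the base model id (19 or 27) and {seed} is a sweep seed (30-39).
set -euo pipefail

names=()
for run_id in 19 27; do
  for seed in {30..39}; do
    names+=("al_1.0_g_0.98_id_${run_id}_shared_init_seed_${seed}")
  done
done

# One name per line: 2 base models x 10 seeds.
printf '%s\n' "${names[@]}"
```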
## Reproduced with

The sweep is defined in [`train.yaml`](./train.yaml) in this repository. From the
[timaeus monorepo](https://github.com/timaeus-research/timaeus), run:

```bash
timaeus run train.yaml
```