---
# Resume of jaxgmg2_3phase_optim_state (checkpoint 3810, alpha 0.6) with the
# optimizer state restored and alpha raised to 1.0.
parameters:
  project_name: jaxgmg2_3phase_optim_state_alpha1
  action: rl
  rl_action: train

  # Learning — alpha changed to 1.0 from base model's 0.6
  # lr is written with a '.' so YAML 1.1 loaders (e.g. PyYAML) resolve it as a
  # float; a bare `5e-5` would be loaded as the string "5e-5".
  lr: 5.0e-5
  alpha: 1.0
  discount_rate: 0.98
  cheese_loc: any
  env_layout: open
  mask_type: first_episode
  use_prev_action: false

  # Training scale
  # Plain digits (no '_' separators) so YAML 1.2 core-schema loaders also read
  # an int: 12_000 grad steps * 64 rollout steps * 9600 levels.
  num_total_env_steps: 7372800000
  num_levels: 9600
  grad_acc_per_chunk: 4
  num_rollout_steps: 64

  # Resume from jaxgmg2_3phase_optim_state checkpoint 3810, with optimizer state
  resume: jaxgmg2_3phase_optim_state/al_0.6_g_0.98_id_17_seed_980617
  resume_id: 3810
  resume_optim: true

  # Checkpointing
  ckpt_dir: jaxgmg2_3phase_optim_state_alpha1
  checkpoint: al_0.6_g_0.98_id_17_seed_980617_resume_alpha1
  # Quoted so the comma/colon list stays a single string.
  eval_schedule: "0:1,250:2,500:5,2000:10"
  log_optimizer_state: true
  deterministic: true

  # Logging
  use_wandb: true
  use_hf: true
  hf_user: urdshals
  wandb_project: jaxgmg2_patt

# Per-run overrides applied on top of `parameters`
# NOTE(review): nesting inferred from the collapsed source — confirm `sweep` is
# top-level for the consuming launcher.
sweep:
  - seed: 42