David Quarel
Add README.md and train.yaml
0be7032
# Training-run configuration. Per the notes below, this run resumes an earlier
# run's checkpoint (with optimizer state) and trains with alpha set to 1.0.
parameters:
  project_name: jaxgmg2_3phase_optim_state_alpha1
  action: rl
  rl_action: train

  # Learning — alpha changed to 1.0 from base model's 0.6
  # `lr` is written with an explicit mantissa dot: YAML 1.1 loaders such as
  # PyYAML resolve a bare `5e-5` to the string "5e-5" instead of a float.
  lr: 5.0e-5
  alpha: 1.0
  discount_rate: 0.98
  cheese_loc: any
  env_layout: open
  mask_type: first_episode
  use_prev_action: false

  # Training scale
  # Total = 12_000 grad steps * 64 rollout steps * 9600 levels.
  # Plain digits (no `_` separators) so YAML 1.1 and 1.2 loaders both read
  # an integer.
  num_total_env_steps: 7372800000
  num_levels: 9600
  grad_acc_per_chunk: 4
  num_rollout_steps: 64

  # Resume from jaxgmg2_3phase_optim_state checkpoint 3810, with optimizer
  # state
  resume: jaxgmg2_3phase_optim_state/al_0.6_g_0.98_id_17_seed_980617
  resume_id: 3810
  resume_optim: true

  # Checkpointing
  ckpt_dir: jaxgmg2_3phase_optim_state_alpha1
  # Name is the resumed run's identifier (see `resume`) plus the suffix
  # `_resume_alpha1`.
  checkpoint: al_0.6_g_0.98_id_17_seed_980617_resume_alpha1
  # NOTE(review): looks like comma-separated `step:value` pairs — confirm the
  # exact meaning against the code that parses this string.
  eval_schedule: "0:1,250:2,500:5,2000:10"
  log_optimizer_state: true
  deterministic: true

  # Logging
  use_wandb: true
  use_hf: true
  hf_user: urdshals
  wandb_project: jaxgmg2_patt
# Sweep entries; this file defines a single entry that sets `seed`.
sweep:
  - seed: 42