div18 committed on
Commit
86112b1
·
1 Parent(s): b1e6564

Avoid OOM: reduce loss_batch_size from 8 to 4 and DEFAULT_MAX_STEPS from 40 to 20

Browse files
Files changed (2) hide show
  1. training/config.yaml +1 -1
  2. training/launch_train.py +1 -1
training/config.yaml CHANGED
@@ -46,7 +46,7 @@ loss_type: "reinforce_baseline" # reinforce | reinforce_baseline | grpo
46
  num_episodes_per_iteration: 4 # Safe now: max_seq_len=512 + loss_batch_size=8 + CPU offload
47
  num_iterations: 500 # Total training iterations
48
  parallel_episodes: true # Batch generation across episodes (10x faster)
49
- loss_batch_size: 8 # Safe now: for_training removed, ~19 GiB freed
50
  learning_rate: 2.0e-4
51
  per_device_train_batch_size: 2 # A10G can handle 2 with seq_len=1024
52
  gradient_accumulation_steps: 4 # Effective batch = 2*4 = 8 transitions
 
46
  num_episodes_per_iteration: 4 # Safe now: max_seq_len=512 + loss_batch_size=4 + CPU offload
47
  num_iterations: 500 # Total training iterations
48
  parallel_episodes: true # Batch generation across episodes (10x faster)
49
+ loss_batch_size: 4 # Safe now: for_training removed, ~19 GiB freed
50
  learning_rate: 2.0e-4
51
  per_device_train_batch_size: 2 # A10G can handle 2 with seq_len=1024
52
  gradient_accumulation_steps: 4 # Effective batch = 2*4 = 8 transitions
training/launch_train.py CHANGED
@@ -78,7 +78,7 @@ DOCKER_IMAGE = "pytorch/pytorch:2.10.0-cuda12.6-cudnn9-devel"
78
 
79
  DEFAULT_NUM_ITERATIONS = 500
80
  DEFAULT_NUM_EPISODES = 4
81
- DEFAULT_MAX_STEPS = 40
82
  DEFAULT_EVAL_INTERVAL = 50
83
  DEFAULT_CHECKPOINT_INTERVAL = 25
84
  DEFAULT_PLOT_INTERVAL = 25
 
78
 
79
  DEFAULT_NUM_ITERATIONS = 500
80
  DEFAULT_NUM_EPISODES = 4
81
+ DEFAULT_MAX_STEPS = 20
82
  DEFAULT_EVAL_INTERVAL = 50
83
  DEFAULT_CHECKPOINT_INTERVAL = 25
84
  DEFAULT_PLOT_INTERVAL = 25