Rithwik Ravi committed on
Commit
3c20800
·
1 Parent(s): 80b34d1

chore(training): optimize GRPO params for sub-4h target on RTX 4070

Browse files
Files changed (1) hide show
  1. src/rl/train_grpo.py +4 -4
src/rl/train_grpo.py CHANGED
@@ -133,10 +133,10 @@ def train():
133
  output_dir="outputs",
134
  learning_rate=1e-5,
135
  per_device_train_batch_size=4, # Pushing 8GB VRAM to 95% util
136
- gradient_accumulation_steps=4, # Effective batch size 16
137
- num_generations=4, # Fix: Reduce from 8 to 4 to prevent OOM / Shared Memory Swapping
138
- max_steps=250, # 30-45 mins on RTX 4070
139
- max_completion_length=1024, # Fix: Prevent 256 token cutoff
140
  max_prompt_length=512,
141
  logging_steps=1,
142
  save_steps=50,
 
133
  output_dir="outputs",
134
  learning_rate=1e-5,
135
  per_device_train_batch_size=4, # Pushing 8GB VRAM to 95% util
136
+ gradient_accumulation_steps=8, # Effective batch size 32 (4 per device x 8 accumulation steps)
137
+ num_generations=4, # Optimize sampling
138
+ max_steps=120, # Sub-4 hour target
139
+ max_completion_length=512, # Shorten generations
140
  max_prompt_length=512,
141
  logging_steps=1,
142
  save_steps=50,