Rithwik Ravi commited on
Commit ·
3c20800
1
Parent(s): 80b34d1
chore(training): optimize GRPO params for sub-4h target on RTX 4070
Browse files- src/rl/train_grpo.py +4 -4
src/rl/train_grpo.py
CHANGED
|
@@ -133,10 +133,10 @@ def train():
|
|
| 133 |
output_dir="outputs",
|
| 134 |
learning_rate=1e-5,
|
| 135 |
per_device_train_batch_size=4, # Pushing 8GB VRAM to 95% util
|
| 136 |
-
gradient_accumulation_steps=
|
| 137 |
-
num_generations=4, #
|
| 138 |
-
max_steps=
|
| 139 |
-
max_completion_length=
|
| 140 |
max_prompt_length=512,
|
| 141 |
logging_steps=1,
|
| 142 |
save_steps=50,
|
|
|
|
| 133 |
output_dir="outputs",
|
| 134 |
learning_rate=1e-5,
|
| 135 |
per_device_train_batch_size=4, # Pushing 8GB VRAM to 95% util
|
| 136 |
+
gradient_accumulation_steps=8, # VRAM efficiency
|
| 137 |
+
num_generations=4, # Optimize sampling
|
| 138 |
+
max_steps=120, # Sub-4 hour target
|
| 139 |
+
max_completion_length=512, # Shorten generations
|
| 140 |
max_prompt_length=512,
|
| 141 |
logging_steps=1,
|
| 142 |
save_steps=50,
|