Spaces:
Sleeping
Sleeping
Commit ·
32d5b8f
1
Parent(s): 26e9b86
feat: implement Unsloth GRPO training script with environment-based reward tracking and balanced dataset generation
Browse files- scripts/train_unsloth.py +34 -18
scripts/train_unsloth.py
CHANGED
|
@@ -14,6 +14,7 @@ Fixed:
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
import argparse
|
|
|
|
| 17 |
import json
|
| 18 |
import math
|
| 19 |
import os
|
|
@@ -667,24 +668,39 @@ def main():
|
|
| 667 |
})
|
| 668 |
print(f"✅ Dataset ready: {len(dataset)} training prompts")
|
| 669 |
|
| 670 |
-
|
| 671 |
-
output_dir
|
| 672 |
-
num_train_epochs
|
| 673 |
-
max_steps
|
| 674 |
-
per_device_train_batch_size
|
| 675 |
-
gradient_accumulation_steps
|
| 676 |
-
num_generations
|
| 677 |
-
max_prompt_length
|
| 678 |
-
max_completion_length
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 688 |
|
| 689 |
reward_fn = GridMindRewardFn(args.env_url, num_steps=8)
|
| 690 |
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
import argparse
|
| 17 |
+
import inspect
|
| 18 |
import json
|
| 19 |
import math
|
| 20 |
import os
|
|
|
|
| 668 |
})
|
| 669 |
print(f"✅ Dataset ready: {len(dataset)} training prompts")
|
| 670 |
|
| 671 |
+
requested_training_args = {
|
| 672 |
+
"output_dir": args.output_dir,
|
| 673 |
+
"num_train_epochs": args.epochs,
|
| 674 |
+
"max_steps": args.max_steps,
|
| 675 |
+
"per_device_train_batch_size": 1,
|
| 676 |
+
"gradient_accumulation_steps": 4,
|
| 677 |
+
"num_generations": 4, # FIXED: was 2, need 4 for variance
|
| 678 |
+
"max_prompt_length": 256,
|
| 679 |
+
"max_completion_length": 128,
|
| 680 |
+
"max_new_tokens": 128,
|
| 681 |
+
"learning_rate": 5e-6, # FIXED: was 5e-5, too high
|
| 682 |
+
"lr_scheduler_type": "cosine",
|
| 683 |
+
"warmup_ratio": 0.1,
|
| 684 |
+
"logging_steps": 5,
|
| 685 |
+
"save_steps": 100,
|
| 686 |
+
"fp16": True,
|
| 687 |
+
"report_to": "none",
|
| 688 |
+
"seed": 42,
|
| 689 |
+
}
|
| 690 |
+
grpo_config_params = set(inspect.signature(GRPOConfig.__init__).parameters) - {"self"}
|
| 691 |
+
training_arg_kwargs = {
|
| 692 |
+
key: value for key, value in requested_training_args.items()
|
| 693 |
+
if key in grpo_config_params
|
| 694 |
+
}
|
| 695 |
+
if "max_completion_length" in training_arg_kwargs and "max_new_tokens" in training_arg_kwargs:
|
| 696 |
+
training_arg_kwargs.pop("max_new_tokens")
|
| 697 |
+
skipped_training_args = [
|
| 698 |
+
key for key in requested_training_args
|
| 699 |
+
if key not in grpo_config_params
|
| 700 |
+
]
|
| 701 |
+
if skipped_training_args:
|
| 702 |
+
print(f"Skipping unsupported GRPOConfig args: {skipped_training_args}")
|
| 703 |
+
training_args = GRPOConfig(**training_arg_kwargs)
|
| 704 |
|
| 705 |
reward_fn = GridMindRewardFn(args.env_url, num_steps=8)
|
| 706 |
|