Prajwal782007 committed on
Commit
32d5b8f
·
1 Parent(s): 26e9b86

feat: implement Unsloth GRPO training script with environment-based reward tracking and balanced dataset generation

Browse files
Files changed (1) hide show
  1. scripts/train_unsloth.py +34 -18
scripts/train_unsloth.py CHANGED
@@ -14,6 +14,7 @@ Fixed:
14
  """
15
 
16
  import argparse
 
17
  import json
18
  import math
19
  import os
@@ -667,24 +668,39 @@ def main():
667
  })
668
  print(f"✅ Dataset ready: {len(dataset)} training prompts")
669
 
670
- training_args = GRPOConfig(
671
- output_dir=args.output_dir,
672
- num_train_epochs=args.epochs,
673
- max_steps=args.max_steps,
674
- per_device_train_batch_size=1,
675
- gradient_accumulation_steps=4,
676
- num_generations=4, # FIXED: was 2, need 4 for variance
677
- max_prompt_length=256,
678
- max_completion_length=128,
679
- learning_rate=5e-6, # FIXED: was 5e-5, too high
680
- lr_scheduler_type="cosine",
681
- warmup_ratio=0.1,
682
- logging_steps=5,
683
- save_steps=100,
684
- fp16=True,
685
- report_to="none",
686
- seed=42,
687
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
688
 
689
  reward_fn = GridMindRewardFn(args.env_url, num_steps=8)
690
 
 
14
  """
15
 
16
  import argparse
17
+ import inspect
18
  import json
19
  import math
20
  import os
 
668
  })
669
  print(f"✅ Dataset ready: {len(dataset)} training prompts")
670
 
671
+ requested_training_args = {
672
+ "output_dir": args.output_dir,
673
+ "num_train_epochs": args.epochs,
674
+ "max_steps": args.max_steps,
675
+ "per_device_train_batch_size": 1,
676
+ "gradient_accumulation_steps": 4,
677
+ "num_generations": 4, # FIXED: was 2, need 4 for variance
678
+ "max_prompt_length": 256,
679
+ "max_completion_length": 128,
680
+ "max_new_tokens": 128,
681
+ "learning_rate": 5e-6, # FIXED: was 5e-5, too high
682
+ "lr_scheduler_type": "cosine",
683
+ "warmup_ratio": 0.1,
684
+ "logging_steps": 5,
685
+ "save_steps": 100,
686
+ "fp16": True,
687
+ "report_to": "none",
688
+ "seed": 42,
689
+ }
690
+ grpo_config_params = set(inspect.signature(GRPOConfig.__init__).parameters) - {"self"}
691
+ training_arg_kwargs = {
692
+ key: value for key, value in requested_training_args.items()
693
+ if key in grpo_config_params
694
+ }
695
+ if "max_completion_length" in training_arg_kwargs and "max_new_tokens" in training_arg_kwargs:
696
+ training_arg_kwargs.pop("max_new_tokens")
697
+ skipped_training_args = [
698
+ key for key in requested_training_args
699
+ if key not in grpo_config_params
700
+ ]
701
+ if skipped_training_args:
702
+ print(f"Skipping unsupported GRPOConfig args: {skipped_training_args}")
703
+ training_args = GRPOConfig(**training_arg_kwargs)
704
 
705
  reward_fn = GridMindRewardFn(args.env_url, num_steps=8)
706