ligaments-dev committed on
Commit
28a17b6
·
verified ·
1 Parent(s): d7c4f5a

Upload grpo_training.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. grpo_training.py +18 -0
grpo_training.py CHANGED
@@ -60,9 +60,27 @@ config = GRPOConfig(
60
  project="ligaments-sec-alignment",
61
  )
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # Initialize GRPO trainer
64
  trainer = GRPOTrainer(
65
  model=model_name,
 
66
  train_dataset=train_dataset,
67
  eval_dataset=eval_dataset,
68
  peft_config=LoraConfig(
 
60
  project="ligaments-sec-alignment",
61
  )
62
 
63
# Reward function for GRPO: prefers concise responses.
# NOTE(review): TRL's GRPOTrainer typically invokes reward funcs with keyword
# arguments (prompts=..., completions=...) — confirm this positional `samples`
# signature matches the installed TRL version.
def preference_reward_func(samples):
    """Score each sample by response brevity.

    Each sample is a mapping with a "response" string. Responses under 50
    words earn 1.0, under 100 words earn 0.5, and anything longer earns 0.0
    (penalizing overly verbose output).
    """
    def _brevity_score(sample):
        # Word count drives the reward tier.
        word_count = len(sample["response"].split())
        if word_count < 50:
            return 1.0
        if word_count < 100:
            return 0.5
        return 0.0  # overly verbose

    return [_brevity_score(sample) for sample in samples]

80
  # Initialize GRPO trainer
81
  trainer = GRPOTrainer(
82
  model=model_name,
83
+ reward_funcs=[preference_reward_func],
84
  train_dataset=train_dataset,
85
  eval_dataset=eval_dataset,
86
  peft_config=LoraConfig(