AbeBhatti commited on
Commit
d39a5d1
·
1 Parent(s): afd245f
Files changed (1) hide show
  1. ppo_train.py +26 -8
ppo_train.py CHANGED
@@ -88,15 +88,17 @@ def main() -> None:
88
  dataset = Dataset.from_dict({"prompt": prompts, "reward": base_rewards})
89
  print(f"Dataset size: {len(dataset)} examples")
90
 
91
- # Custom reward function for GRPO: reward model only (no random env noise).
92
  def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
93
  """
94
- Score each completion using the reward model encoder + head only.
95
- Deterministic and low-noise compared to environment-based rewards.
96
  """
97
  scores: List[float] = []
98
- for i, completion in enumerate(completions):
99
  text = completion.strip()
 
 
 
100
  inputs = reward_tokenizer(
101
  text,
102
  return_tensors="pt",
@@ -108,10 +110,26 @@ def main() -> None:
108
 
109
  with torch.no_grad():
110
  hidden = reward_model.encoder(**inputs).last_hidden_state[:, 0, :]
111
- score = reward_model.head(hidden).squeeze().item()
112
-
113
- print(f"Completion {i}: score={score:.4f} | text={completion[:50]}")
114
- scores.append(torch.tensor(score, dtype=torch.float32))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  return scores
116
 
117
  # GRPO configuration: small batch, multiple generations per prompt.
 
88
  dataset = Dataset.from_dict({"prompt": prompts, "reward": base_rewards})
89
  print(f"Dataset size: {len(dataset)} examples")
90
 
91
+ # Custom reward function for GRPO: reward model + repetition & length penalties.
92
  def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
93
  """
94
+ Score each completion with reward model, then apply repetition and length penalties.
 
95
  """
96
  scores: List[float] = []
97
+ for completion in completions:
98
  text = completion.strip()
99
+ words = text.split()
100
+
101
+ # Reward model score
102
  inputs = reward_tokenizer(
103
  text,
104
  return_tensors="pt",
 
110
 
111
  with torch.no_grad():
112
  hidden = reward_model.encoder(**inputs).last_hidden_state[:, 0, :]
113
+ rm_score = reward_model.head(hidden).squeeze().item()
114
+
115
+ # Repetition penalty: ratio of unique words to total words
116
+ if len(words) > 0:
117
+ unique_ratio = len(set(words)) / len(words)
118
+ else:
119
+ unique_ratio = 0.0
120
+
121
+ if unique_ratio < 0.3:
122
+ repetition_penalty = -2.0
123
+ elif unique_ratio < 0.5:
124
+ repetition_penalty = -0.5
125
+ else:
126
+ repetition_penalty = 0.0
127
+
128
+ # Penalty for very short completions
129
+ length_penalty = -0.5 if len(words) < 5 else 0.0
130
+
131
+ combined = float(rm_score) + repetition_penalty + length_penalty
132
+ scores.append(torch.tensor(combined, dtype=torch.float32))
133
  return scores
134
 
135
  # GRPO configuration: small batch, multiple generations per prompt.