Spaces:
Runtime error
Runtime error
AbeBhatti committed on
Commit ·
d39a5d1
1
Parent(s): afd245f
initial
Browse files- ppo_train.py +26 -8
ppo_train.py
CHANGED
|
@@ -88,15 +88,17 @@ def main() -> None:
|
|
| 88 |
dataset = Dataset.from_dict({"prompt": prompts, "reward": base_rewards})
|
| 89 |
print(f"Dataset size: {len(dataset)} examples")
|
| 90 |
|
| 91 |
-
# Custom reward function for GRPO: reward model
|
| 92 |
def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
|
| 93 |
"""
|
| 94 |
-
Score each completion
|
| 95 |
-
Deterministic and low-noise compared to environment-based rewards.
|
| 96 |
"""
|
| 97 |
scores: List[float] = []
|
| 98 |
-
for
|
| 99 |
text = completion.strip()
|
|
|
|
|
|
|
|
|
|
| 100 |
inputs = reward_tokenizer(
|
| 101 |
text,
|
| 102 |
return_tensors="pt",
|
|
@@ -108,10 +110,26 @@ def main() -> None:
|
|
| 108 |
|
| 109 |
with torch.no_grad():
|
| 110 |
hidden = reward_model.encoder(**inputs).last_hidden_state[:, 0, :]
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
return scores
|
| 116 |
|
| 117 |
# GRPO configuration: small batch, multiple generations per prompt.
|
|
|
|
| 88 |
dataset = Dataset.from_dict({"prompt": prompts, "reward": base_rewards})
|
| 89 |
print(f"Dataset size: {len(dataset)} examples")
|
| 90 |
|
| 91 |
# Custom reward function for GRPO: reward model + repetition & length penalties.
# NOTE(review): in the original file this is defined inside main() and closes
# over `reward_tokenizer` and `reward_model` created earlier in main().
def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
    """Score each completion with the reward model, then apply penalties.

    Args:
        completions: Generated texts to score, one per sampled completion.
        prompts: Prompts the completions came from (unused here, but part of
            the TRL reward-function signature).
        **kwargs: Extra dataset columns forwarded by the trainer (ignored).

    Returns:
        One plain ``float`` per completion: reward-model score plus a
        repetition penalty and a short-length penalty.
    """
    scores: List[float] = []
    for completion in completions:
        text = completion.strip()
        words = text.split()

        # Reward-model score: CLS-token embedding fed through the scalar head.
        inputs = reward_tokenizer(
            text,
            return_tensors="pt",
            # NOTE(review): additional tokenizer kwargs (likely truncation /
            # max_length) were elided in the diff view — restore them from the
            # original file before relying on this reconstruction.
        )
        with torch.no_grad():
            hidden = reward_model.encoder(**inputs).last_hidden_state[:, 0, :]
            rm_score = reward_model.head(hidden).squeeze().item()

        # Repetition penalty — a low unique-word ratio means heavy repetition.
        unique_ratio = len(set(words)) / len(words) if words else 0.0
        if unique_ratio < 0.3:
            repetition_penalty = -2.0
        elif unique_ratio < 0.5:
            repetition_penalty = -0.5
        else:
            repetition_penalty = 0.0

        # Penalty for very short completions (fewer than 5 words).
        length_penalty = -0.5 if len(words) < 5 else 0.0

        # BUG FIX: the original appended torch.tensor(combined), so the
        # function returned List[Tensor] despite declaring -> List[float].
        # Append a plain float, matching the annotation and the expected
        # reward-function contract.
        scores.append(float(rm_score) + repetition_penalty + length_penalty)
    return scores
|
| 134 |
|
| 135 |
# GRPO configuration: small batch, multiple generations per prompt.
|