File size: 6,672 Bytes
dbc69f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | from __future__ import annotations
from .registry import register_objective
@register_objective("token_grpo")
class TokenGRPOObjective:
name = "token_grpo"
def __init__(
self,
enable_length_penalty: bool = True,
enable_token_utilisation_reward: bool = False,
reward_mode: str = "additive",
correctness_weight: float = 1.0,
length_penalty_lambda: float = 1.0,
stage2_length_penalty_lambda: float | None = None,
stage2_start_epoch: float = 1.0,
length_penalty_interaction: str = "correctness_length",
strict_format_gate: bool = True,
non_strict_penalty: float = -1.0,
**kwargs,
):
del kwargs
self.enable_length_penalty = bool(enable_length_penalty)
self.enable_token_utilisation_reward = bool(enable_token_utilisation_reward)
self.reward_mode = reward_mode
self.correctness_weight = float(correctness_weight)
self.length_penalty_lambda = float(length_penalty_lambda)
self.stage2_length_penalty_lambda = (
None
if stage2_length_penalty_lambda is None
else float(stage2_length_penalty_lambda)
)
self.stage2_start_epoch = float(stage2_start_epoch)
self.length_penalty_interaction = length_penalty_interaction
self.strict_format_gate = bool(strict_format_gate)
self.non_strict_penalty = float(non_strict_penalty)
allowed_modes = {"additive", "weighted_length_penalty"}
if self.reward_mode not in allowed_modes:
allowed = ", ".join(sorted(allowed_modes))
raise ValueError(
f"Unknown reward_mode '{self.reward_mode}'. Allowed: {allowed}"
)
allowed_interactions = {"correctness_length", "correctness_length_format"}
if self.length_penalty_interaction not in allowed_interactions:
allowed = ", ".join(sorted(allowed_interactions))
raise ValueError(
"Unknown length_penalty_interaction "
f"'{self.length_penalty_interaction}'. Allowed: {allowed}"
)
def reward_names(self) -> list[str]:
names = [
"format_tag_reward",
"gsm8k_correctness_reward",
]
if self.enable_length_penalty:
names.append("length_penalty_reward")
if self.enable_token_utilisation_reward:
names.append("token_utilisation_reward")
return names
def combine_rewards(
self, reward_outputs: dict[str, list[float]], objective_extra: list[float]
) -> list[float]:
format_scores = reward_outputs.get("format_tag_reward", [])
del objective_extra
if self.reward_mode == "additive":
sources = list(reward_outputs.values())
if not sources:
return []
width = len(sources[0])
totals = [0.0] * width
for scores in sources:
if len(scores) != width:
raise ValueError("Reward functions returned inconsistent lengths.")
totals = [a + b for a, b in zip(totals, scores, strict=True)]
if self.strict_format_gate:
gated: list[float] = []
for idx, total in enumerate(totals):
if float(format_scores[idx]) > 0.5:
gated.append(float(total))
else:
gated.append(self.non_strict_penalty)
return gated
return totals
# weighted_length_penalty:
# default: R_total = R_correctness + lambda * (R_correctness * R_length) + R_format
# alt: R_total = R_correctness + lambda * (R_correctness * R_length * R_format) + R_format
correctness_scores = reward_outputs.get("gsm8k_correctness_reward", [])
if len(format_scores) != len(correctness_scores):
raise ValueError("format_tag_reward and gsm8k_correctness_reward must align.")
if self.enable_length_penalty:
length_scores = reward_outputs.get(
"length_penalty_reward", [0.0] * len(correctness_scores)
)
if len(length_scores) != len(correctness_scores):
raise ValueError("length_penalty_reward must align with core rewards.")
else:
length_scores = [0.0] * len(correctness_scores)
token_utilisation_scores = reward_outputs.get(
"token_utilisation_reward", [0.0] * len(correctness_scores)
)
if len(token_utilisation_scores) != len(correctness_scores):
raise ValueError("token_utilisation_reward must align with core rewards.")
totals: list[float] = []
for idx in range(len(correctness_scores)):
r_correctness = float(correctness_scores[idx])
r_format = float(format_scores[idx])
r_length = float(length_scores[idx])
if self.strict_format_gate and r_format <= 0.5:
totals.append(self.non_strict_penalty)
continue
interaction = r_correctness * r_length
if self.length_penalty_interaction == "correctness_length_format":
interaction *= r_format
total = (
self.correctness_weight * r_correctness
+ self.length_penalty_lambda * interaction
+ r_format
+ float(token_utilisation_scores[idx])
)
totals.append(total)
return totals
def extra_reward(
self,
prompts: list[str],
completions: list[str],
references: list[str],
metadata: list[dict],
) -> list[float]:
del prompts, references, metadata
return [0.0 for _ in completions]
@register_objective("latent_neuralese")
class LatentNeuraleseObjective:
    """Placeholder objective for latent-space rewards.

    Requests only the format reward and contributes a zero extra reward
    until latent scoring is implemented.
    """

    name = "latent_neuralese"

    def __init__(self, representation_key: str = "hidden_state", **kwargs):
        """Remember the representation key and any extra options verbatim."""
        self.kwargs = kwargs
        self.representation_key = representation_key

    def reward_names(self) -> list[str]:
        """Return the names of the reward functions this objective uses."""
        # Token-level rewards stay in place while latent objectives are
        # being bootstrapped.
        token_level_rewards = ["format_tag_reward"]
        return token_level_rewards

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return one zero per completion (baseline no-op).

        ``prompts``, ``references`` and ``metadata`` are accepted but unused.
        Replace the body with latent-space scoring from activations; keeping
        it isolated here allows iteration without touching trainer internals.
        """
        return [0.0 for _completion in completions]
|