| from __future__ import annotations |
|
|
| from .registry import register_objective |
|
|
|
|
@register_objective("token_grpo")
class TokenGRPOObjective:
    """Reward-combination objective for token-level GRPO training.

    The objective declares which reward functions should run
    (:meth:`reward_names`) and folds their per-completion outputs into one
    scalar per completion (:meth:`combine_rewards`). Two modes exist:

    * ``"additive"`` -- every list in ``reward_outputs`` is summed
      element-wise.
    * ``"weighted_length_penalty"`` -- per completion::

          correctness_weight * c + length_penalty_lambda * interaction + f + t

      where ``c``, ``f`` and ``t`` are the correctness, format and
      token-utilisation scores and ``interaction`` is ``c * length``
      (additionally multiplied by ``f`` when ``length_penalty_interaction``
      is ``"correctness_length_format"``).

    In both modes, when ``strict_format_gate`` is enabled, a completion whose
    format score is not strictly greater than 0.5 receives
    ``non_strict_penalty`` instead of its combined reward.
    """

    name = "token_grpo"

    def __init__(
        self,
        enable_length_penalty: bool = True,
        enable_token_utilisation_reward: bool = False,
        reward_mode: str = "additive",
        correctness_weight: float = 1.0,
        length_penalty_lambda: float = 1.0,
        stage2_length_penalty_lambda: float | None = None,
        stage2_start_epoch: float = 1.0,
        length_penalty_interaction: str = "correctness_length",
        strict_format_gate: bool = True,
        non_strict_penalty: float = -1.0,
        **kwargs,
    ):
        """Store and validate the objective configuration.

        Args:
            enable_length_penalty: Request ``length_penalty_reward`` and use
                it in weighted mode; when False, weighted mode treats the
                length score as 0.0.
            enable_token_utilisation_reward: Request
                ``token_utilisation_reward``.
            reward_mode: ``"additive"`` or ``"weighted_length_penalty"``.
            correctness_weight: Multiplier on the correctness score
                (weighted mode only).
            length_penalty_lambda: Multiplier on the correctness/length
                interaction term (weighted mode only).
            stage2_length_penalty_lambda: Optional alternative lambda; stored
                only.
            stage2_start_epoch: Epoch associated with the stage-2 lambda;
                stored only.
            length_penalty_interaction: ``"correctness_length"`` or
                ``"correctness_length_format"``.
            strict_format_gate: Replace the reward of completions whose format
                score is <= 0.5 with ``non_strict_penalty``.
            non_strict_penalty: Reward assigned to gated completions.
            **kwargs: Accepted for config compatibility and ignored.

        Raises:
            ValueError: If ``reward_mode`` or ``length_penalty_interaction`` is
                not one of the allowed values.
        """
        del kwargs
        self.enable_length_penalty = bool(enable_length_penalty)
        self.enable_token_utilisation_reward = bool(enable_token_utilisation_reward)
        self.reward_mode = reward_mode
        self.correctness_weight = float(correctness_weight)
        self.length_penalty_lambda = float(length_penalty_lambda)
        # NOTE(review): the two stage2_* attributes are not read by any method
        # of this class; presumably an external trainer/scheduler switches
        # length_penalty_lambda at stage2_start_epoch -- confirm with callers.
        self.stage2_length_penalty_lambda = (
            None
            if stage2_length_penalty_lambda is None
            else float(stage2_length_penalty_lambda)
        )
        self.stage2_start_epoch = float(stage2_start_epoch)
        self.length_penalty_interaction = length_penalty_interaction
        self.strict_format_gate = bool(strict_format_gate)
        self.non_strict_penalty = float(non_strict_penalty)

        allowed_modes = {"additive", "weighted_length_penalty"}
        if self.reward_mode not in allowed_modes:
            allowed = ", ".join(sorted(allowed_modes))
            raise ValueError(
                f"Unknown reward_mode '{self.reward_mode}'. Allowed: {allowed}"
            )
        allowed_interactions = {"correctness_length", "correctness_length_format"}
        if self.length_penalty_interaction not in allowed_interactions:
            allowed = ", ".join(sorted(allowed_interactions))
            raise ValueError(
                "Unknown length_penalty_interaction "
                f"'{self.length_penalty_interaction}'. Allowed: {allowed}"
            )

    def reward_names(self) -> list[str]:
        """Return the names of the reward functions this objective needs.

        Format and correctness rewards are always requested; the length and
        token-utilisation rewards are appended when enabled.
        """
        names = [
            "format_tag_reward",
            "gsm8k_correctness_reward",
        ]
        if self.enable_length_penalty:
            names.append("length_penalty_reward")
        if self.enable_token_utilisation_reward:
            names.append("token_utilisation_reward")
        return names

    def combine_rewards(
        self, reward_outputs: dict[str, list[float]], objective_extra: list[float]
    ) -> list[float]:
        """Combine per-reward scores into one scalar reward per completion.

        Args:
            reward_outputs: Mapping from reward name to a list of scores, one
                per completion.
            objective_extra: Output of :meth:`extra_reward`; currently ignored.

        Returns:
            One combined reward per completion.

        Raises:
            ValueError: If the reward lists needed by the active mode do not
                all have the same length (including a missing
                ``format_tag_reward`` when the format gate needs it).
        """
        del objective_extra
        if self.reward_mode == "additive":
            return self._combine_additive(reward_outputs)
        return self._combine_weighted(reward_outputs)

    def _combine_additive(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Sum every reward list element-wise, then apply the format gate."""
        sources = list(reward_outputs.values())
        if not sources:
            return []
        width = len(sources[0])
        totals = [0.0] * width
        for scores in sources:
            if len(scores) != width:
                raise ValueError("Reward functions returned inconsistent lengths.")
            totals = [a + b for a, b in zip(totals, scores, strict=True)]
        if not self.strict_format_gate:
            return totals

        format_scores = reward_outputs.get("format_tag_reward", [])
        # The gate needs one format score per completion; without this check a
        # missing format_tag_reward surfaced as a bare IndexError.
        if len(format_scores) != width:
            raise ValueError(
                "format_tag_reward must align with other rewards when "
                "strict_format_gate is enabled."
            )
        return [
            float(total) if float(fmt) > 0.5 else self.non_strict_penalty
            for total, fmt in zip(totals, format_scores, strict=True)
        ]

    def _combine_weighted(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Apply the weighted correctness/length/format/token formula."""
        format_scores = reward_outputs.get("format_tag_reward", [])
        correctness_scores = reward_outputs.get("gsm8k_correctness_reward", [])
        count = len(correctness_scores)
        if len(format_scores) != count:
            raise ValueError("format_tag_reward and gsm8k_correctness_reward must align.")

        if self.enable_length_penalty:
            # A missing length reward is treated as "no penalty" (all zeros).
            length_scores = reward_outputs.get("length_penalty_reward", [0.0] * count)
            if len(length_scores) != count:
                raise ValueError("length_penalty_reward must align with core rewards.")
        else:
            length_scores = [0.0] * count
        # Token utilisation contributes whenever its output is present,
        # independent of enable_token_utilisation_reward.
        token_utilisation_scores = reward_outputs.get(
            "token_utilisation_reward", [0.0] * count
        )
        if len(token_utilisation_scores) != count:
            raise ValueError("token_utilisation_reward must align with core rewards.")

        totals: list[float] = []
        for correctness, fmt, length, token_util in zip(
            correctness_scores,
            format_scores,
            length_scores,
            token_utilisation_scores,
            strict=True,
        ):
            r_correctness = float(correctness)
            r_format = float(fmt)
            r_length = float(length)
            if self.strict_format_gate and r_format <= 0.5:
                totals.append(self.non_strict_penalty)
                continue
            # The length term only bites on correct answers (scaled by
            # correctness), optionally also scaled by the format score.
            interaction = r_correctness * r_length
            if self.length_penalty_interaction == "correctness_length_format":
                interaction *= r_format
            totals.append(
                self.correctness_weight * r_correctness
                + self.length_penalty_lambda * interaction
                + r_format
                + float(token_util)
            )
        return totals

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return the objective-specific extra reward: 0.0 per completion."""
        del prompts, references, metadata
        return [0.0 for _ in completions]
|
|
|
|
@register_objective("latent_neuralese")
class LatentNeuraleseObjective:
    """Minimal objective registered under ``"latent_neuralese"``.

    It requests only the format reward and contributes an extra reward of
    0.0 for every completion. Constructor arguments are stored on the
    instance; no method shown here reads them.
    """

    name = "latent_neuralese"

    def __init__(self, representation_key: str = "hidden_state", **kwargs):
        """Record the configuration.

        Args:
            representation_key: Stored as ``self.representation_key``.
                NOTE(review): presumably selects which model output serves as
                the latent representation -- confirm with its consumer.
            **kwargs: Any further options, stored as ``self.kwargs``.
        """
        self.kwargs = kwargs
        self.representation_key = representation_key

    def reward_names(self) -> list[str]:
        """Return the reward functions to run: only the format reward."""
        requested: list[str] = ["format_tag_reward"]
        return requested

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return a zero extra reward for each completion.

        Only the number of completions matters; the other arguments exist to
        satisfy the objective interface and are ignored.
        """
        del prompts, references, metadata
        return [0.0] * len(completions)
|
|