"""Reward objectives registered with the objective registry."""

from __future__ import annotations

from .registry import register_objective


@register_objective("token_grpo")
class TokenGRPOObjective:
    """Combine per-completion reward scores into a single scalar reward.

    Two combination modes are supported (``reward_mode``):

    * ``"additive"``: element-wise sum of every reward list that was
      produced.
    * ``"weighted_length_penalty"``: per completion,
      ``correctness_weight * R_correctness
      + length_penalty_lambda * interaction
      + R_format + R_token_utilisation``
      where ``interaction`` is ``R_correctness * R_length`` and is
      additionally multiplied by ``R_format`` when
      ``length_penalty_interaction == "correctness_length_format"``.

    In both modes, when ``strict_format_gate`` is enabled, any completion
    whose format score is not above ``0.5`` receives ``non_strict_penalty``
    instead of the combined value.
    """

    name = "token_grpo"

    def __init__(
        self,
        enable_length_penalty: bool = True,
        enable_token_utilisation_reward: bool = False,
        reward_mode: str = "additive",
        correctness_weight: float = 1.0,
        length_penalty_lambda: float = 1.0,
        stage2_length_penalty_lambda: float | None = None,
        stage2_start_epoch: float = 1.0,
        length_penalty_interaction: str = "correctness_length",
        strict_format_gate: bool = True,
        non_strict_penalty: float = -1.0,
        **kwargs,
    ):
        """Store and validate the objective configuration.

        Args:
            enable_length_penalty: Include ``length_penalty_reward`` in
                ``reward_names`` and in the weighted combination.
            enable_token_utilisation_reward: Include
                ``token_utilisation_reward`` in ``reward_names``.
            reward_mode: ``"additive"`` or ``"weighted_length_penalty"``.
            correctness_weight: Multiplier on the correctness score in
                weighted mode.
            length_penalty_lambda: Multiplier on the interaction term in
                weighted mode.
            stage2_length_penalty_lambda: Stored as a float (or ``None``);
                not read by this class itself.
            stage2_start_epoch: Stored as a float; not read by this class
                itself.
            length_penalty_interaction: ``"correctness_length"`` or
                ``"correctness_length_format"``.
            strict_format_gate: Replace the reward of completions whose
                format score is ``<= 0.5`` with ``non_strict_penalty``.
            non_strict_penalty: Reward assigned to gated completions.
            **kwargs: Accepted and discarded.

        Raises:
            ValueError: If ``reward_mode`` or ``length_penalty_interaction``
                is not one of the allowed values.
        """
        del kwargs  # Unknown options are tolerated and ignored.

        self.enable_length_penalty = bool(enable_length_penalty)
        self.enable_token_utilisation_reward = bool(enable_token_utilisation_reward)
        self.reward_mode = reward_mode
        self.correctness_weight = float(correctness_weight)
        self.length_penalty_lambda = float(length_penalty_lambda)
        self.stage2_length_penalty_lambda = (
            None
            if stage2_length_penalty_lambda is None
            else float(stage2_length_penalty_lambda)
        )
        self.stage2_start_epoch = float(stage2_start_epoch)
        self.length_penalty_interaction = length_penalty_interaction
        self.strict_format_gate = bool(strict_format_gate)
        self.non_strict_penalty = float(non_strict_penalty)

        allowed_modes = {"additive", "weighted_length_penalty"}
        if self.reward_mode not in allowed_modes:
            allowed = ", ".join(sorted(allowed_modes))
            raise ValueError(
                f"Unknown reward_mode '{self.reward_mode}'. Allowed: {allowed}"
            )

        allowed_interactions = {"correctness_length", "correctness_length_format"}
        if self.length_penalty_interaction not in allowed_interactions:
            allowed = ", ".join(sorted(allowed_interactions))
            raise ValueError(
                "Unknown length_penalty_interaction "
                f"'{self.length_penalty_interaction}'. Allowed: {allowed}"
            )

    def reward_names(self) -> list[str]:
        """Return the names of the reward functions this objective uses."""
        names = [
            "format_tag_reward",
            "gsm8k_correctness_reward",
        ]
        if self.enable_length_penalty:
            names.append("length_penalty_reward")
        if self.enable_token_utilisation_reward:
            names.append("token_utilisation_reward")
        return names

    def combine_rewards(
        self, reward_outputs: dict[str, list[float]], objective_extra: list[float]
    ) -> list[float]:
        """Combine the individual reward lists into one reward per completion.

        Args:
            reward_outputs: Mapping from reward name to its per-completion
                scores.
            objective_extra: Ignored by this objective.

        Returns:
            One combined reward per completion (empty when there is nothing
            to combine).

        Raises:
            ValueError: If the reward lists involved have inconsistent
                lengths, or if the strict format gate is enabled in additive
                mode but ``format_tag_reward`` does not cover every
                completion.
        """
        del objective_extra

        if self.reward_mode == "additive":
            return self._combine_additive(reward_outputs)
        return self._combine_weighted(reward_outputs)

    def _combine_additive(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Sum every reward list element-wise, then apply the format gate."""
        sources = list(reward_outputs.values())
        if not sources:
            return []

        width = len(sources[0])
        totals = [0.0] * width
        for scores in sources:
            if len(scores) != width:
                raise ValueError("Reward functions returned inconsistent lengths.")
            totals = [a + b for a, b in zip(totals, scores, strict=True)]

        if not self.strict_format_gate:
            return totals

        format_scores = reward_outputs.get("format_tag_reward", [])
        # Every present reward list already has length `width`, so a mismatch
        # here means the format reward is missing altogether. Fail with a
        # clear message instead of an IndexError while gating.
        if len(format_scores) != width:
            raise ValueError(
                "strict_format_gate requires format_tag_reward scores "
                "for every completion."
            )

        return [
            float(total) if float(format_score) > 0.5 else self.non_strict_penalty
            for total, format_score in zip(totals, format_scores, strict=True)
        ]

    def _combine_weighted(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Apply the ``weighted_length_penalty`` formula per completion."""
        format_scores = reward_outputs.get("format_tag_reward", [])
        correctness_scores = reward_outputs.get("gsm8k_correctness_reward", [])
        if len(format_scores) != len(correctness_scores):
            raise ValueError(
                "format_tag_reward and gsm8k_correctness_reward must align."
            )

        count = len(correctness_scores)

        if self.enable_length_penalty:
            length_scores = reward_outputs.get("length_penalty_reward", [0.0] * count)
            if len(length_scores) != count:
                raise ValueError("length_penalty_reward must align with core rewards.")
        else:
            # A disabled length penalty contributes nothing to the interaction.
            length_scores = [0.0] * count

        token_utilisation_scores = reward_outputs.get(
            "token_utilisation_reward", [0.0] * count
        )
        if len(token_utilisation_scores) != count:
            raise ValueError("token_utilisation_reward must align with core rewards.")

        include_format = self.length_penalty_interaction == "correctness_length_format"

        totals: list[float] = []
        rows = zip(
            correctness_scores,
            format_scores,
            length_scores,
            token_utilisation_scores,
            strict=True,
        )
        for raw_correctness, raw_format, raw_length, raw_utilisation in rows:
            r_correctness = float(raw_correctness)
            r_format = float(raw_format)
            r_length = float(raw_length)

            if self.strict_format_gate and r_format <= 0.5:
                totals.append(self.non_strict_penalty)
                continue

            interaction = r_correctness * r_length
            if include_format:
                interaction *= r_format

            totals.append(
                self.correctness_weight * r_correctness
                + self.length_penalty_lambda * interaction
                + r_format
                + float(raw_utilisation)
            )
        return totals

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return a zero extra reward for every completion."""
        del prompts, references, metadata
        return [0.0 for _ in completions]


@register_objective("latent_neuralese")
class LatentNeuraleseObjective:
    """Placeholder objective whose extra reward is currently a no-op."""

    name = "latent_neuralese"

    def __init__(self, representation_key: str = "hidden_state", **kwargs):
        """Store the representation key and any additional options."""
        self.representation_key = representation_key
        self.kwargs = kwargs

    def reward_names(self) -> list[str]:
        """Return the names of the reward functions this objective uses."""
        # Keep token-level rewards while you bootstrap latent objectives.
        return ["format_tag_reward"]

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return a zero extra reward for every completion."""
        del prompts, references, metadata
        # Baseline no-op: replace with latent-space scoring from activations.
        # This is intentionally isolated so you can iterate without touching
        # trainer internals.
        return [0.0 for _ in completions]