neuralese_temp / src /hackable /objectives.py
psidharth567's picture
Export neuralese codebase (cache and .env excluded).
dbc69f3
from __future__ import annotations
from .registry import register_objective
@register_objective("token_grpo")
class TokenGRPOObjective:
    """Combine token-level GRPO reward-function outputs into one reward per completion.

    ``reward_mode`` selects how the per-function score series are combined:

    * ``"additive"``: every series in ``reward_outputs`` is summed element-wise.
    * ``"weighted_length_penalty"``::

          R_total = w * R_correctness
                    + lambda * (R_correctness * R_length [* R_format])
                    + R_format
                    + R_token_utilisation

      with ``w = correctness_weight`` and ``lambda = length_penalty_lambda``.
      The bracketed ``R_format`` factor is applied only when
      ``length_penalty_interaction == "correctness_length_format"``. Missing
      length / token-utilisation series are treated as all zeros.

    In both modes, when ``strict_format_gate`` is set, a completion whose
    format score is ``<= 0.5`` receives ``non_strict_penalty`` instead of its
    combined reward.
    """

    name = "token_grpo"

    _ALLOWED_REWARD_MODES = frozenset({"additive", "weighted_length_penalty"})
    _ALLOWED_INTERACTIONS = frozenset(
        {"correctness_length", "correctness_length_format"}
    )
    # A format score must be strictly greater than this to pass the strict gate.
    _FORMAT_GATE_THRESHOLD = 0.5

    def __init__(
        self,
        enable_length_penalty: bool = True,
        enable_token_utilisation_reward: bool = False,
        reward_mode: str = "additive",
        correctness_weight: float = 1.0,
        length_penalty_lambda: float = 1.0,
        stage2_length_penalty_lambda: float | None = None,
        stage2_start_epoch: float = 1.0,
        length_penalty_interaction: str = "correctness_length",
        strict_format_gate: bool = True,
        non_strict_penalty: float = -1.0,
        **kwargs,
    ):
        """Store and validate the reward-combination configuration.

        Args:
            enable_length_penalty: Request ``length_penalty_reward`` in
                ``reward_names()`` and use it in weighted mode.
            enable_token_utilisation_reward: Request
                ``token_utilisation_reward`` in ``reward_names()``.
            reward_mode: ``"additive"`` or ``"weighted_length_penalty"``.
            correctness_weight: Multiplier on the correctness reward
                (weighted mode only).
            length_penalty_lambda: Multiplier on the interaction term
                (weighted mode only).
            stage2_length_penalty_lambda: Stored but not read by this class.
                NOTE(review): presumably consumed by a schedule elsewhere that
                swaps ``length_penalty_lambda`` -- confirm.
            stage2_start_epoch: Stored but not read by this class (see the
                note on ``stage2_length_penalty_lambda``).
            length_penalty_interaction: ``"correctness_length"`` or
                ``"correctness_length_format"``.
            strict_format_gate: Replace the reward of completions whose format
                score is ``<= 0.5`` with ``non_strict_penalty``.
            non_strict_penalty: Reward assigned to gated completions.
            **kwargs: Ignored and discarded (presumably so a shared config can
                carry keys meant for other objectives -- confirm).

        Raises:
            ValueError: If ``reward_mode`` or ``length_penalty_interaction``
                is not one of the allowed values.
        """
        del kwargs
        self.enable_length_penalty = bool(enable_length_penalty)
        self.enable_token_utilisation_reward = bool(enable_token_utilisation_reward)
        self.reward_mode = reward_mode
        self.correctness_weight = float(correctness_weight)
        self.length_penalty_lambda = float(length_penalty_lambda)
        self.stage2_length_penalty_lambda = (
            None
            if stage2_length_penalty_lambda is None
            else float(stage2_length_penalty_lambda)
        )
        self.stage2_start_epoch = float(stage2_start_epoch)
        self.length_penalty_interaction = length_penalty_interaction
        self.strict_format_gate = bool(strict_format_gate)
        self.non_strict_penalty = float(non_strict_penalty)

        if self.reward_mode not in self._ALLOWED_REWARD_MODES:
            allowed = ", ".join(sorted(self._ALLOWED_REWARD_MODES))
            raise ValueError(
                f"Unknown reward_mode '{self.reward_mode}'. Allowed: {allowed}"
            )
        if self.length_penalty_interaction not in self._ALLOWED_INTERACTIONS:
            allowed = ", ".join(sorted(self._ALLOWED_INTERACTIONS))
            raise ValueError(
                "Unknown length_penalty_interaction "
                f"'{self.length_penalty_interaction}'. Allowed: {allowed}"
            )

    def reward_names(self) -> list[str]:
        """Return the names of the reward functions this objective needs."""
        names = [
            "format_tag_reward",
            "gsm8k_correctness_reward",
        ]
        if self.enable_length_penalty:
            names.append("length_penalty_reward")
        if self.enable_token_utilisation_reward:
            names.append("token_utilisation_reward")
        return names

    def combine_rewards(
        self, reward_outputs: dict[str, list[float]], objective_extra: list[float]
    ) -> list[float]:
        """Combine per-function reward series into one total per completion.

        Args:
            reward_outputs: Mapping from reward-function name to its scores,
                one score per completion.
            objective_extra: Unused by this objective.

        Returns:
            One combined reward per completion (see the class docstring).

        Raises:
            ValueError: If the reward series have mismatched lengths, or if
                the strict format gate is enabled in additive mode without
                aligned ``format_tag_reward`` scores.
        """
        del objective_extra
        if self.reward_mode == "additive":
            return self._combine_additive(reward_outputs)
        return self._combine_weighted(reward_outputs)

    def _combine_additive(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Sum every reward series element-wise, then apply the optional format gate."""
        sources = list(reward_outputs.values())
        if not sources:
            return []
        width = len(sources[0])
        totals = [0.0] * width
        for scores in sources:
            if len(scores) != width:
                raise ValueError("Reward functions returned inconsistent lengths.")
            totals = [a + b for a, b in zip(totals, scores, strict=True)]
        if not self.strict_format_gate:
            return totals

        format_scores = reward_outputs.get("format_tag_reward", [])
        # A present format series already matched `width` above, so this only
        # fires when it is missing; fail explicitly instead of with IndexError.
        if len(format_scores) != width:
            raise ValueError(
                "strict_format_gate requires format_tag_reward aligned with "
                "the other rewards."
            )
        return [
            float(total)
            if float(fmt) > self._FORMAT_GATE_THRESHOLD
            else self.non_strict_penalty
            for total, fmt in zip(totals, format_scores, strict=True)
        ]

    def _combine_weighted(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Apply the weighted correctness/length/format formula per completion."""
        correctness_scores = reward_outputs.get("gsm8k_correctness_reward", [])
        format_scores = reward_outputs.get("format_tag_reward", [])
        count = len(correctness_scores)
        if len(format_scores) != count:
            raise ValueError("format_tag_reward and gsm8k_correctness_reward must align.")

        # Shared read-only default for optional series that were not computed.
        zeros = [0.0] * count
        if self.enable_length_penalty:
            length_scores = reward_outputs.get("length_penalty_reward", zeros)
            if len(length_scores) != count:
                raise ValueError("length_penalty_reward must align with core rewards.")
        else:
            length_scores = zeros

        # Added whenever present, independent of enable_token_utilisation_reward.
        token_utilisation_scores = reward_outputs.get("token_utilisation_reward", zeros)
        if len(token_utilisation_scores) != count:
            raise ValueError("token_utilisation_reward must align with core rewards.")

        include_format = self.length_penalty_interaction == "correctness_length_format"
        totals: list[float] = []
        for correctness, fmt, length, utilisation in zip(
            correctness_scores,
            format_scores,
            length_scores,
            token_utilisation_scores,
            strict=True,
        ):
            r_correctness = float(correctness)
            r_format = float(fmt)
            r_length = float(length)
            if self.strict_format_gate and r_format <= self._FORMAT_GATE_THRESHOLD:
                totals.append(self.non_strict_penalty)
                continue
            interaction = r_correctness * r_length
            if include_format:
                interaction *= r_format
            totals.append(
                self.correctness_weight * r_correctness
                + self.length_penalty_lambda * interaction
                + r_format
                + float(utilisation)
            )
        return totals

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return a zero extra reward for each completion; other inputs are ignored."""
        del prompts, references, metadata
        return [0.0 for _ in completions]
@register_objective("latent_neuralese")
class LatentNeuraleseObjective:
    """Placeholder objective for latent-space ("neuralese") training.

    For now it only requests the token-level format reward, and its
    objective-specific extra reward is a constant zero per completion.
    """

    name = "latent_neuralese"

    def __init__(self, representation_key: str = "hidden_state", **kwargs):
        """Keep configuration for a future latent scorer.

        Args:
            representation_key: Stored as-is; not read anywhere in this class yet.
            **kwargs: Any further options, kept verbatim on ``self.kwargs``.
        """
        self.kwargs = kwargs
        self.representation_key = representation_key

    def reward_names(self) -> list[str]:
        """Return the reward functions to run (only the format check)."""
        # Token-level rewards stay active while latent objectives are bootstrapped.
        requested = ["format_tag_reward"]
        return requested

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return 0.0 for every completion (no-op baseline).

        Only ``completions`` is consulted, and only to size the result; the
        remaining arguments are ignored.
        """
        # TODO: replace with latent-space scoring from activations. Keeping the
        # stub here lets that work iterate without touching trainer internals.
        _ = (prompts, references, metadata)
        scores: list[float] = []
        for _completion in completions:
            scores.append(0.0)
        return scores