"""Composite reward — weighted sum of correctness, lint, runtime, length. R = w_c * correctness + w_l * lint + w_r * runtime + w_len * length All component rewards are in `[0, 1]` (larger = better). Composite output range is `[0, sum(weights)]`. Default weights are set so correctness dominates by ~20× over each auxiliary signal; change in Week 5 ablations. Absolute scale matters less than rank ordering — GRPO normalizes per-group advantages anyway. What matters: for two candidates A and B where A is more correct, `composite_reward(A) > composite_reward(B)` robustly. """ from __future__ import annotations from dataclasses import dataclass from ..sandbox.runner import RunResult from .correctness import RewardMode, correctness_reward from .lint import lint_reward from .runtime import DEFAULT_BUDGET_MS, runtime_reward DEFAULT_LENGTH_BUDGET_CHARS: int = 2000 @dataclass(frozen=True) class RewardWeights: """Per-component weights for the composite reward. Defaults put correctness 20× higher than each auxiliary signal, so lint/ runtime/length act as tiebreakers among equally-correct solutions and cannot flip a pass into a fail. """ correctness: float = 1.0 lint: float = 0.05 runtime: float = 0.05 length: float = 0.01 # Singleton instance so function defaults don't call the constructor each # time (ruff B008). `RewardWeights` is frozen, so sharing this across callers # is safe. DEFAULT_WEIGHTS = RewardWeights() def length_reward(code: str, *, budget_chars: int = DEFAULT_LENGTH_BUDGET_CHARS) -> float: """Shorter code → higher reward, bounded `[0, 1]`. `len(code) = 0` → 1.0 `len(code) = budget` → 0.0 `len(code) > budget` → 0.0 (clamped) """ if budget_chars <= 0: return 0.0 return max(0.0, (budget_chars - len(code)) / budget_chars) def composite_reward( code: str, result: RunResult, *, weights: RewardWeights = DEFAULT_WEIGHTS, mode: RewardMode = "binary", runtime_budget_ms: int = DEFAULT_BUDGET_MS, length_budget_chars: int = DEFAULT_LENGTH_BUDGET_CHARS, ) -> float: """Return the weighted sum of per-component rewards.""" return ( weights.correctness * correctness_reward(result, mode=mode) + weights.lint * lint_reward(code) + weights.runtime * runtime_reward(result, budget_ms=runtime_budget_ms) + weights.length * length_reward(code, budget_chars=length_budget_chars) )