dmaheshwar22's picture
deploy: replace template with real demo
0dd7c80 verified
"""Composite reward — weighted sum of correctness, lint, runtime, length.
R = w_c * correctness + w_l * lint + w_r * runtime + w_len * length
All component rewards are in `[0, 1]` (larger = better). Composite output
range is `[0, sum(weights)]`. Default weights are set so correctness
dominates by at least ~20× over each auxiliary signal; change in Week 5 ablations.
Absolute scale matters less than rank ordering — GRPO normalizes per-group
advantages anyway. What matters: for two candidates A and B where A is
more correct, `composite_reward(A) > composite_reward(B)` robustly.
"""
from __future__ import annotations
from dataclasses import dataclass
from ..sandbox.runner import RunResult
from .correctness import RewardMode, correctness_reward
from .lint import lint_reward
from .runtime import DEFAULT_BUDGET_MS, runtime_reward
# Default character budget for the length component. It is the default value
# of `budget_chars` in `length_reward` and of `length_budget_chars` in
# `composite_reward`; code at or above this length earns a length reward of 0.
DEFAULT_LENGTH_BUDGET_CHARS: int = 2000
@dataclass(frozen=True)
class RewardWeights:
    """Per-component weights for the composite reward.

    With the defaults, correctness outweighs lint and runtime by 20× and
    length by 100×. The auxiliary weights sum to 0.11, so as long as every
    component reward stays in `[0, 1]` they act as tiebreakers among
    equally-correct solutions and cannot flip a pass into a fail.

    Instances are frozen (immutable), so one instance can be shared safely.
    """

    # Weight of the correctness component (the dominant signal).
    correctness: float = 1.0
    # Weight of the lint component.
    lint: float = 0.05
    # Weight of the runtime component.
    runtime: float = 0.05
    # Weight of the length component (smallest of the auxiliary weights).
    length: float = 0.01
# Singleton instance so function defaults don't call the constructor each
# time (ruff B008). `RewardWeights` is frozen, so sharing this across callers
# is safe. Used as the default `weights` argument of `composite_reward`.
DEFAULT_WEIGHTS = RewardWeights()
def length_reward(code: str, *, budget_chars: int = DEFAULT_LENGTH_BUDGET_CHARS) -> float:
    """Shorter code → higher reward, bounded `[0, 1]`.

    The reward falls linearly with `len(code)`:

    `len(code) = 0` → 1.0
    `len(code) = budget` → 0.0
    `len(code) > budget` → 0.0 (clamped)

    Args:
        code: Source text whose character count is scored.
        budget_chars: Length at which the reward reaches zero.

    Returns:
        A float in `[0, 1]`; 0.0 when `budget_chars` is zero or negative.
    """
    # A non-positive budget would divide by zero or invert the scale, so it
    # is treated as "no reward available".
    if budget_chars <= 0:
        return 0.0
    # Fraction of the budget left unused; clamped at zero for over-budget code.
    return max(0.0, (budget_chars - len(code)) / budget_chars)
def composite_reward(
    code: str,
    result: RunResult,
    *,
    weights: RewardWeights = DEFAULT_WEIGHTS,
    mode: RewardMode = "binary",
    runtime_budget_ms: int = DEFAULT_BUDGET_MS,
    length_budget_chars: int = DEFAULT_LENGTH_BUDGET_CHARS,
) -> float:
    """Return the weighted sum of per-component rewards.

    Computes
    `w_c * correctness + w_l * lint + w_r * runtime + w_len * length`.

    Args:
        code: Candidate source text; passed to `lint_reward` and
            `length_reward`.
        result: `RunResult` from the sandbox runner; passed to
            `correctness_reward` and `runtime_reward`.
        weights: Per-component weights applied to each term.
        mode: Forwarded to `correctness_reward` as `mode`.
        runtime_budget_ms: Forwarded to `runtime_reward` as `budget_ms`.
        length_budget_chars: Forwarded to `length_reward` as `budget_chars`.

    Returns:
        The weighted sum of the four component rewards.
    """
    return (
        weights.correctness * correctness_reward(result, mode=mode)
        + weights.lint * lint_reward(code)
        + weights.runtime * runtime_reward(result, budget_ms=runtime_budget_ms)
        + weights.length * length_reward(code, budget_chars=length_budget_chars)
    )