explainer-env / rewards /generation.py
kgdrathan's picture
Upload folder using huggingface_hub
5869d56 verified
"""Reward components for the generation phase.

After exploration, the agent generates marimo/manim code. Rewards measure
validity, task alignment, artifact structure, and research usage.

Scoring model:
    quality = weighted sum of (validity, task alignment, structure, research usage)
    total   = quality × gate

Gates (multiplicative):
    - code doesn't parse  → total = 0
    - static check fails  → total = quality × small static-fail multiplier
    - code doesn't run    → total = quality × execution-fail multiplier
    - code runs           → total = quality × 1.0
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from .sandbox import ast_parses, check_marimo, extract_scene_class
try:
from ..constants import MAX_REPAIR_REWARD, clamp_action_reward
except ImportError: # pragma: no cover - supports direct test execution
from constants import MAX_REPAIR_REWARD, clamp_action_reward
if TYPE_CHECKING:
from ..task_bank import Task
# ---------------------------------------------------------------------------
# Component weights
# ---------------------------------------------------------------------------
_WEIGHTS = {
"validity": 0.15,
"task_alignment": 0.30,
"structure": 0.30,
"research_usage": 0.25,
}
GATE_STATIC_FAIL = 0.12
GATE_RUNS_FAIL = 0.30 # quality multiplier when static checks pass but execution fails
_STOPWORDS = {
"about", "after", "again", "against", "also", "because", "before", "being",
"between", "class", "code", "construct", "could", "from", "have", "into",
"like", "make", "more", "most", "only", "self", "show", "step", "than",
"that", "their", "then", "there", "these", "this", "through", "using",
"value", "where", "with", "would",
}
# ---------------------------------------------------------------------------
# Individual scorers
# ---------------------------------------------------------------------------
def keyword_coverage(code: str, keywords_csv: str) -> float:
    """Return the share of task keywords that appear in ``code``.

    ``keywords_csv`` is a comma-separated list; entries are stripped and
    lower-cased, and blank entries are dropped. A keyword counts as present
    when it occurs as a substring of the lower-cased code. An empty keyword
    string, or one with no non-blank entries, scores 0.0.
    """
    if not keywords_csv:
        return 0.0

    wanted = []
    for raw in keywords_csv.split(","):
        cleaned = raw.strip()
        if cleaned:
            wanted.append(cleaned.lower())
    if not wanted:
        return 0.0

    haystack = code.lower()
    found = [kw for kw in wanted if kw in haystack]
    return len(found) / len(wanted)
def format_match(chosen_format: str, task: Task) -> float:
    """Score how well the chosen output format fits the task.

    Returns 1.0 when the task states no preference (``preferred_format`` is
    ``None``) or when ``chosen_format`` equals the preference; otherwise 0.3.
    """
    preferred = task.preferred_format
    if preferred is None or chosen_format == preferred:
        return 1.0
    return 0.3
def marimo_structure(
    code: str,
    task: Task,
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> float:
    """Heuristically rate the structure of a marimo notebook, clamped to [0, 1].

    Credit is added for substring evidence of good patterns (marimo import,
    app object, cells, UI/layout helpers, plotting); credit is removed for a
    few discouraged snippets and for each static-check violation.

    Args:
        code: Notebook source text.
        task: Task whose ``tier`` and ``data_available`` tune the scoring.
        static_check_passed: Outcome of a static check already run by the
            caller. When ``None``, ``check_marimo`` is run on ``code`` here.
        error_codes: Violation codes accompanying ``static_check_passed``;
            ignored when the check is run here.

    Returns:
        The structure score in the range 0.0-1.0.
    """

    def mentions(*needles: str) -> bool:
        # True when at least one of the snippets occurs in the notebook text.
        return any(needle in code for needle in needles)

    total = 0.0

    # --- Positive signals -------------------------------------------------
    if mentions("import marimo", "from marimo"):
        total += 0.2
    if mentions("marimo.App", "mo.App"):
        total += 0.1

    n_cells = code.count("@app.cell")
    if n_cells >= 3:
        total += 0.2
    elif n_cells >= 1:
        total += 0.1

    # Each distinct UI/layout helper is worth 0.06, capped at 0.22 overall.
    ui_markers = (
        "mo.md(",
        "mo.Html",
        "mo.accordion",
        "mo.callout",
        "mo.hstack(",
        "mo.vstack(",
        "mo.ui.slider",
        "mo.ui.dropdown",
        "mo.ui.table",
        "mo.ui.dataframe",
    )
    total += min(0.22, sum(0.06 for marker in ui_markers if marker in code))

    # Plotting: only the best-matching tier applies, and each tier is worth
    # more when the task ships data.
    if "mo.ui.matplotlib(plt.gca())" in code:
        total += 0.24 if task.data_available else 0.16
    elif mentions("mo.ui.matplotlib(", "mo.ui.plotly(", "mo.ui.altair_chart("):
        total += 0.18 if task.data_available else 0.10
    elif mentions("plt.", "matplotlib.pyplot", "px.", "plotly.", "alt.Chart"):
        total += 0.08 if task.data_available else 0.03
        # NOTE(review): the source this was recovered from had lost its
        # indentation; this deduction is read as belonging to the raw-plot
        # branch (plot present but not wrapped in a mo.ui.* element).
        # Confirm against the upstream file.
        total -= 0.08

    # --- Discouraged snippets ----------------------------------------------
    if "plt.tight_layout(" in code:
        total -= 0.12
    if "np.math." in code:
        total -= 0.15

    # Bonus for reaching the cell count expected at the task's tier.
    cells_expected = {"advanced": 6, "intermediate": 4, "beginner": 2}
    if n_cells >= cells_expected.get(task.tier, 2):
        total += 0.1

    # --- Static check: bonus when clean, per-violation penalty otherwise ----
    if static_check_passed is None:
        passed, _, violations = check_marimo(code)
    else:
        passed, violations = static_check_passed, error_codes or []

    if passed:
        total += 0.1
    else:
        deductions = {
            "MB001": 0.3,
            "MB002": 0.35,
            "MB003": 0.4,
            "MB004": 0.2,
            "MB005": 0.25,
        }
        for violation in violations:
            # Codes without a dedicated entry cost a flat 0.15.
            total -= deductions.get(violation, 0.15)

    return max(0.0, min(1.0, total))
def manim_structure(code: str, task: Task) -> float:
    """Heuristically rate the structure of a manim scene, capped at 1.0.

    Credit is awarded for substring evidence of a manim import, a detectable
    scene class, a ``construct`` method, animation calls, and math/plot
    objects, plus a bonus when the number of distinct animation patterns
    reaches the threshold for the task's tier.

    Patterns are matched as plain substrings, so ``"Transform("`` is also
    found inside ``"ReplacementTransform("`` and ``"Tex("`` inside
    ``"MathTex("``.

    Args:
        code: Scene source text.
        task: Task whose ``tier`` selects the animation-count threshold
            (unknown tiers fall back to 2).

    Returns:
        The structure score, at most 1.0.
    """
    # ``extract_scene_class`` is the module-level import also used by
    # ``_infer_static_check``; no function-local re-import is needed.
    score = 0.0

    if "from manim" in code or "import manim" in code:
        score += 0.2
    if extract_scene_class(code) is not None:
        score += 0.2
    if "def construct" in code:
        score += 0.1

    anim_patterns = [
        "self.play(", "self.wait(", "Create(", "FadeIn(", "FadeOut(",
        "Transform(", "Write(", "MoveToTarget", "Indicate(",
        "ReplacementTransform(",
    ]
    # Each distinct pattern present is worth 0.05, capped at 0.3.
    anim_hits = sum(1 for p in anim_patterns if p in code)
    score += min(0.3, anim_hits * 0.05)

    math_patterns = ["MathTex(", "Tex(", "Axes(", "NumberPlane(", "Graph("]
    if any(p in code for p in math_patterns):
        score += 0.1

    tier_thresholds = {"advanced": 6, "intermediate": 4, "beginner": 2}
    if anim_hits >= tier_thresholds.get(task.tier, 2):
        score += 0.1

    return min(1.0, score)
def narration_score(narration: str, fmt: str) -> float:
    """Rate narration text for a manim artifact, capped at 1.0.

    Formats other than ``"manim"`` always score 1.0. For manim, missing or
    whitespace-only narration scores 0.0; otherwise credit comes from two
    word-count tiers plus sequencing words found (as substrings) in the
    lower-cased text.
    """
    if fmt != "manim":
        return 1.0
    if not narration or not narration.strip():
        return 0.0

    n_words = len(narration.split())
    lowered = narration.lower()

    # First length tier.
    if n_words >= 30:
        length_credit = 0.4
    elif n_words >= 10:
        length_credit = 0.2
    else:
        length_credit = 0.0

    # 0.1 per distinct sequencing word present, capped at 0.3.
    cue_words = ("scene", "step", "first", "next", "then", "finally", "now")
    cue_credit = min(0.3, sum(0.1 for cue in cue_words if cue in lowered))

    # Second length tier, stacked on top of the first.
    if n_words >= 50:
        depth_credit = 0.3
    elif n_words >= 20:
        depth_credit = 0.15
    else:
        depth_credit = 0.0

    total = 0.0
    total += length_credit
    total += cue_credit
    total += depth_credit
    return min(1.0, total)
def context_usage(code: str, accumulated_context: list[str]) -> float:
    """Rate how much of the gathered research vocabulary shows up in ``code``.

    Both sides are tokenised with ``_tokens``. The score is the number of
    shared tokens divided by the context vocabulary size (capped at 24),
    scaled by 2.5 and limited to 1.0. It is 0.0 when there is no context, no
    usable context tokens, or no overlap.
    """
    if not accumulated_context:
        return 0.0

    vocabulary: set[str] = {
        token for snippet in accumulated_context for token in _tokens(snippet)
    }
    if not vocabulary:
        return 0.0

    shared = set(_tokens(code)).intersection(vocabulary)
    if not shared:
        return 0.0

    # Cap the denominator so a large context does not make the target
    # unreachable; the 2.5 factor means a partial slice can already hit 1.0.
    denominator = min(max(len(vocabulary), 1), 24)
    return min(1.0, len(shared) / denominator * 2.5)
# ---------------------------------------------------------------------------
# Main reward function
# ---------------------------------------------------------------------------
def compute_generate_reward(
    code: str,
    fmt: str,
    narration: str,
    task: Task,
    exec_success: bool,
    accumulated_context: list[str],
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> tuple[float, dict]:
    """Compute the generation-phase reward.

    A quality score is built as the ``_WEIGHTS``-weighted sum of validity,
    task alignment, structure and research usage, then multiplied by a gate:

    * code does not parse          -> 0
    * static check fails           -> ``_static_fail_multiplier(error_codes)``
    * static passes, run fails     -> ``GATE_RUNS_FAIL``
    * everything passes            -> 1.0

    Args:
        code: Generated source text.
        fmt: Artifact format; ``"marimo"`` uses ``marimo_structure``, anything
            else uses ``manim_structure`` blended with ``narration_score``.
        narration: Narration text accompanying the artifact.
        task: Task supplying ``keywords`` and the attributes read by the
            format/structure scorers.
        exec_success: Whether the code ran successfully.
        accumulated_context: Research snippets gathered before generation.
        static_check_passed: Result of a static check already run by the
            caller. When ``None`` it is inferred here.
        error_codes: Violation codes from the caller's static check. When the
            check is inferred here for marimo and no codes were supplied, the
            codes reported by ``check_marimo`` are used instead.

    Returns:
        ``(total, components)`` where ``components`` holds the rounded
        per-component scores and the rounded total under ``generate_total``.
    """
    parse_valid = ast_parses(code)
    c_parse = 1.0 if parse_valid else 0.0

    if static_check_passed is None:
        if parse_valid and fmt == "marimo":
            # Run the marimo check directly so its violation codes survive.
            # Going through _infer_static_check would discard them, which
            # skipped the per-violation structure penalties and selected the
            # lenient non-"MB" static-fail multiplier for inferred failures.
            static_check_passed, _, inferred_codes = check_marimo(code)
            if not error_codes:
                error_codes = list(inferred_codes or [])
        else:
            static_check_passed = _infer_static_check(code, fmt, parse_valid)

    c_static = 1.0 if parse_valid and static_check_passed else 0.0
    c_runs = 1.0 if exec_success else 0.0

    c_coverage = keyword_coverage(code, task.keywords)
    c_format = format_match(fmt, task)

    if fmt == "marimo":
        c_struct = marimo_structure(code, task, static_check_passed, error_codes)
    else:
        scene_structure = manim_structure(code, task)
        c_struct = 0.75 * scene_structure + 0.25 * narration_score(narration, fmt)

    c_ctx = context_usage(code, accumulated_context)
    c_validity = _validity_score(c_parse, c_static, c_runs)
    c_alignment = 0.75 * c_coverage + 0.25 * c_format

    quality = (
        _WEIGHTS["validity"] * c_validity
        + _WEIGHTS["task_alignment"] * c_alignment
        + _WEIGHTS["structure"] * c_struct
        + _WEIGHTS["research_usage"] * c_ctx
    )

    # Apply gates: the first failing stage decides the multiplier.
    if c_parse == 0.0:
        total = 0.0
    elif c_static == 0.0:
        total = quality * _static_fail_multiplier(error_codes or [])
    elif c_runs == 0.0:
        total = quality * GATE_RUNS_FAIL
    else:
        total = quality

    components = {
        "validity": round(c_validity, 3),
        "task_alignment": round(c_alignment, 3),
        "structure": round(c_struct, 3),
        "research_usage": round(c_ctx, 3),
        "generate_total": round(total, 4),
    }
    return total, components
def _infer_static_check(code: str, fmt: str, parse_valid: bool) -> bool:
if not parse_valid:
return False
if fmt == "marimo":
passed, _, _ = check_marimo(code)
return passed
if fmt == "manim":
return extract_scene_class(code) is not None
return False
def _static_fail_multiplier(error_codes: list[str]) -> float:
    """Pick the gate multiplier for an artifact that failed the static check.

    Any code starting with ``"MB"`` selects ``GATE_STATIC_FAIL``; otherwise
    the multiplier is 1.5x that value, never exceeding ``GATE_RUNS_FAIL``.
    """
    for reported in error_codes:
        if reported.startswith("MB"):
            return GATE_STATIC_FAIL
    relaxed = GATE_STATIC_FAIL * 1.5
    return min(GATE_RUNS_FAIL, relaxed)
def _validity_score(
parse_valid: float,
static_check_passed: float,
code_runs: float,
) -> float:
if parse_valid == 0.0:
return 0.0
if static_check_passed == 0.0:
return 0.35
if code_runs == 0.0:
return 0.70
return 1.0
def adjust_repair_reward(
    base_reward: float,
    *,
    repair_success: bool,
    previous_error_codes: list[str],
    new_error_codes: list[str],
    previous_code: str,
    repaired_code: str,
) -> tuple[float, dict]:
    """Scale a generation reward for a repair attempt.

    Repaired code earns only a fraction of ``base_reward`` (0.60 on success,
    0.25 on failure), with small bonuses for clearing every previously
    reported error code and for actually changing the code, and a penalty
    for a failed repair that left the code unchanged (ignoring whitespace).
    The result is passed through ``clamp_action_reward`` and capped at
    ``MAX_REPAIR_REWARD``.

    Returns:
        ``(reward, components)`` with the flags and rounded total.
    """
    code_changed = _fingerprint(previous_code) != _fingerprint(repaired_code)
    # "Fixed" means there were prior errors and none of them recur.
    cleared_prior = bool(previous_error_codes) and set(
        previous_error_codes
    ).isdisjoint(set(new_error_codes))

    if repair_success:
        adjusted = base_reward * 0.60
        adjusted += 0.08 if cleared_prior else 0.0
        adjusted += 0.04 if code_changed else 0.0
    else:
        adjusted = base_reward * 0.25
        adjusted += 0.04 if cleared_prior else 0.0
        if not code_changed:
            adjusted -= 0.15

    adjusted = min(MAX_REPAIR_REWARD, clamp_action_reward(adjusted))

    breakdown = {
        "repair_success": 1.0 if repair_success else 0.0,
        "fixed_prior_errors": 1.0 if cleared_prior else 0.0,
        "changed_code": 1.0 if code_changed else 0.0,
        "repair_total": round(adjusted, 4),
    }
    return adjusted, breakdown
def _tokens(text: str) -> list[str]:
return [
w
for w in re.findall(r"\w+", text.lower())
if len(w) > 3 and w not in _STOPWORDS
]
def _fingerprint(code: str) -> str:
return re.sub(r"\s+", "", code)