"""Reward components for the generation phase.

After exploration, the agent generates marimo/manim code. Rewards measure
validity, task alignment, artifact structure, and research usage.

Scoring model::

    quality = weighted sum of (validity, task alignment, structure, research usage)
    total   = quality × gate

Gates (multiplicative):

- code doesn't parse  → total = 0
- static check fails  → total = quality × small static-fail multiplier
- code doesn't run    → total = quality × execution-fail multiplier
- code runs           → total = quality × 1.0
"""

from __future__ import annotations

import re
from typing import TYPE_CHECKING

from .sandbox import ast_parses, check_marimo, extract_scene_class

try:
    from ..constants import MAX_REPAIR_REWARD, clamp_action_reward
except ImportError:  # pragma: no cover - supports direct test execution
    from constants import MAX_REPAIR_REWARD, clamp_action_reward

if TYPE_CHECKING:
    from ..task_bank import Task

# ---------------------------------------------------------------------------
# Component weights
# ---------------------------------------------------------------------------

_WEIGHTS = {
    "validity": 0.15,
    "task_alignment": 0.30,
    "structure": 0.30,
    "research_usage": 0.25,
}

GATE_STATIC_FAIL = 0.12
GATE_RUNS_FAIL = 0.30  # quality multiplier when static checks pass but execution fails

_STOPWORDS = {
    "about", "after", "again", "against", "also", "because", "before",
    "being", "between", "class", "code", "construct", "could", "from",
    "have", "into", "like", "make", "more", "most", "only", "self", "show",
    "step", "than", "that", "their", "then", "there", "these", "this",
    "through", "using", "value", "where", "with", "would",
}

# Minimum count (of @app.cell cells for marimo, of animation-pattern hits for
# manim) that earns the tier bonus; unknown tiers fall back to 2.
_TIER_THRESHOLDS = {"advanced": 6, "intermediate": 4, "beginner": 2}

# Structure penalty per breaking ``marimo check`` violation code; any code not
# listed here costs 0.15.
_MARIMO_VIOLATION_PENALTY = {
    "MB001": 0.3,
    "MB002": 0.35,
    "MB003": 0.4,
    "MB004": 0.2,
    "MB005": 0.25,
}

_WORD_RE = re.compile(r"\w+")


# ---------------------------------------------------------------------------
# Individual scorers
# ---------------------------------------------------------------------------


def keyword_coverage(code: str, keywords_csv: str) -> float:
    """Fraction of task keywords mentioned in the code (case-insensitive).

    ``keywords_csv`` is a comma-separated list; blank entries are ignored.
    Matching is by substring, so a keyword also counts when it appears inside
    a longer identifier. Returns 0.0 when there are no keywords.
    """
    if not keywords_csv:
        return 0.0
    keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()]
    if not keywords:
        return 0.0
    code_lower = code.lower()
    return sum(1 for kw in keywords if kw in code_lower) / len(keywords)


def format_match(chosen_format: str, task: Task) -> float:
    """1.0 if format matches the task's preferred format, else 0.3.

    If the task has no preferred format (None), any choice scores 1.0.
    """
    if task.preferred_format is None:
        return 1.0
    return 1.0 if chosen_format == task.preferred_format else 0.3


def marimo_structure(
    code: str,
    task: Task,
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> float:
    """Score structural quality of a marimo notebook (0-1).

    Additive scoring for good patterns, penalties from ``marimo check`` for
    breaking violations (duplicate defs, cycles, etc.).

    Args:
        code: Notebook source.
        task: Task whose ``data_available`` and ``tier`` tune the bonuses.
        static_check_passed: Result of an already-run ``marimo check``. When
            None, ``check_marimo`` is run here and its violations are used.
        error_codes: Violation codes to penalize when ``static_check_passed``
            is given; ignored when it is None.

    Returns:
        The score clamped to ``[0.0, 1.0]``.
    """
    score = 0.0

    # Positive signals
    if "import marimo" in code or "from marimo" in code:
        score += 0.2
    if "marimo.App" in code or "mo.App" in code:
        score += 0.1

    cell_count = code.count("@app.cell")
    if cell_count >= 3:
        score += 0.2
    elif cell_count >= 1:
        score += 0.1

    ui_patterns = [
        "mo.md(", "mo.Html", "mo.accordion", "mo.callout", "mo.hstack(",
        "mo.vstack(", "mo.ui.slider", "mo.ui.dropdown", "mo.ui.table",
        "mo.ui.dataframe",
    ]
    score += min(0.22, sum(0.06 for p in ui_patterns if p in code))

    reactive_plot_patterns = [
        "mo.ui.matplotlib(", "mo.ui.plotly(", "mo.ui.altair_chart(",
    ]
    raw_plot_patterns = [
        "plt.", "matplotlib.pyplot", "px.", "plotly.", "alt.Chart",
    ]
    if "mo.ui.matplotlib(plt.gca())" in code:
        score += 0.24 if task.data_available else 0.16
    elif any(p in code for p in reactive_plot_patterns):
        score += 0.18 if task.data_available else 0.10
    elif any(p in code for p in raw_plot_patterns):
        score += 0.08 if task.data_available else 0.03
        # Net effect: unwrapped plots score 0.0 (data) or -0.05 (no data),
        # presumably to steer toward mo.ui wrappers.
        # NOTE(review): placement of this line was reconstructed; confirm.
        score -= 0.08

    if "plt.tight_layout(" in code:
        score -= 0.12
    if "np.math." in code:
        score -= 0.15

    if cell_count >= _TIER_THRESHOLDS.get(task.tier, 2):
        score += 0.1

    # Marimo check: penalize breaking violations, bonus for clean code
    if static_check_passed is None:
        passed, _, violations = check_marimo(code)
    else:
        passed = static_check_passed
        violations = error_codes or []

    if passed:
        score += 0.1
    else:
        for v in violations:
            score -= _MARIMO_VIOLATION_PENALTY.get(v, 0.15)

    return max(0.0, min(1.0, score))


def manim_structure(code: str, task: Task) -> float:
    """Score structural quality of a manim scene (0-1).

    Rewards manim imports, a detectable Scene class, a ``construct`` method,
    animation calls (0.05 each, capped at 0.3), math objects, and enough
    distinct animation patterns for the task's tier.
    """
    score = 0.0

    if "from manim" in code or "import manim" in code:
        score += 0.2
    if extract_scene_class(code) is not None:
        score += 0.2
    if "def construct" in code:
        score += 0.1

    anim_patterns = [
        "self.play(", "self.wait(", "Create(", "FadeIn(", "FadeOut(",
        "Transform(", "Write(", "MoveToTarget", "Indicate(",
        "ReplacementTransform(",
    ]
    anim_hits = sum(1 for p in anim_patterns if p in code)
    score += min(0.3, anim_hits * 0.05)

    math_patterns = ["MathTex(", "Tex(", "Axes(", "NumberPlane(", "Graph("]
    if any(p in code for p in math_patterns):
        score += 0.1

    if anim_hits >= _TIER_THRESHOLDS.get(task.tier, 2):
        score += 0.1

    return min(1.0, score)


def narration_score(narration: str, fmt: str) -> float:
    """Score narration quality. Only relevant for manim format.

    Returns 1.0 for any non-manim format and 0.0 for empty/blank narration.
    Otherwise combines word-count bonuses with up to 0.3 for sequencing
    markers (matched as case-insensitive substrings).
    """
    if fmt != "manim":
        return 1.0
    if not narration or not narration.strip():
        return 0.0

    words = narration.split()
    score = 0.0

    if len(words) >= 30:
        score += 0.4
    elif len(words) >= 10:
        score += 0.2

    scene_markers = ["scene", "step", "first", "next", "then", "finally", "now"]
    narration_lower = narration.lower()
    score += min(0.3, sum(0.1 for m in scene_markers if m in narration_lower))

    if len(words) >= 50:
        score += 0.3
    elif len(words) >= 20:
        score += 0.15

    return min(1.0, score)


def context_usage(code: str, accumulated_context: list[str]) -> float:
    """Score whether the generated code incorporates research findings (0-1).

    Compares the sets of meaningful tokens (see ``_tokens``) in the code and in
    all accumulated context snippets.
    """
    if not accumulated_context:
        return 0.0

    context_words: set[str] = set()
    for ctx in accumulated_context:
        context_words.update(_tokens(ctx))
    if not context_words:
        return 0.0

    overlap = set(_tokens(code)) & context_words
    if not overlap:
        return 0.0

    # Do not reward broad generic overlap too heavily; a few meaningful terms
    # should help, but strong usage needs a substantial slice of the context.
    target = min(len(context_words), 24)  # context_words is non-empty here
    return min(1.0, len(overlap) / target * 2.5)


# ---------------------------------------------------------------------------
# Main reward function
# ---------------------------------------------------------------------------


def compute_generate_reward(
    code: str,
    fmt: str,
    narration: str,
    task: Task,
    exec_success: bool,
    accumulated_context: list[str],
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> tuple[float, dict]:
    """Compute the generation-phase reward.

    Args:
        code: Generated artifact source.
        fmt: ``"marimo"`` uses the marimo structure scorer; any other value is
            scored with the manim structure + narration heuristics.
        narration: Narration text (only scored when ``fmt == "manim"``).
        task: Task providing keywords, preferred format, tier, and data flag.
        exec_success: Whether the artifact executed successfully.
        accumulated_context: Research snippets gathered during exploration.
        static_check_passed: Format-specific static check result. When None
            it is inferred here (``check_marimo`` for marimo, Scene-class
            detection for manim), and for marimo the inferred violation codes
            are used when ``error_codes`` is also None.
        error_codes: Static-check error codes, used for structure penalties
            and to pick the static-fail gate multiplier.

    Returns:
        ``(total, components)``. ``components`` holds ``validity``,
        ``task_alignment``, ``structure`` and ``research_usage`` (rounded to 3
        places) and ``generate_total`` (rounded to 4). Python parseability,
        the static check, and ``exec_success`` act as multiplicative gates on
        the weighted quality score.
    """
    parse_valid = ast_parses(code)
    c_parse = 1.0 if parse_valid else 0.0

    if static_check_passed is None:
        static_check_passed, inferred_codes = _infer_static_check_details(
            code, fmt, parse_valid
        )
        # Keep the inferred violation codes so the structure penalties and
        # the static-fail gate see the same failures a caller-supplied check
        # would report.
        if error_codes is None:
            error_codes = inferred_codes
    c_static = 1.0 if parse_valid and static_check_passed else 0.0
    c_runs = 1.0 if exec_success else 0.0

    c_coverage = keyword_coverage(code, task.keywords)
    c_format = format_match(fmt, task)

    if fmt == "marimo":
        c_struct = marimo_structure(code, task, static_check_passed, error_codes)
    else:
        scene_structure = manim_structure(code, task)
        c_struct = 0.75 * scene_structure + 0.25 * narration_score(narration, fmt)

    c_ctx = context_usage(code, accumulated_context)
    c_validity = _validity_score(c_parse, c_static, c_runs)
    c_alignment = 0.75 * c_coverage + 0.25 * c_format

    quality = (
        _WEIGHTS["validity"] * c_validity
        + _WEIGHTS["task_alignment"] * c_alignment
        + _WEIGHTS["structure"] * c_struct
        + _WEIGHTS["research_usage"] * c_ctx
    )

    # Apply gates
    if c_parse == 0.0:
        total = 0.0
    elif c_static == 0.0:
        total = quality * _static_fail_multiplier(error_codes or [])
    elif c_runs == 0.0:
        total = quality * GATE_RUNS_FAIL
    else:
        total = quality

    components = {
        "validity": round(c_validity, 3),
        "task_alignment": round(c_alignment, 3),
        "structure": round(c_struct, 3),
        "research_usage": round(c_ctx, 3),
        "generate_total": round(total, 4),
    }
    return total, components


def _infer_static_check_details(
    code: str, fmt: str, parse_valid: bool
) -> tuple[bool, list[str]]:
    """Run the format-specific static check; return ``(passed, error_codes)``.

    Only the marimo path produces error codes (from ``check_marimo``); the
    manim path checks for a Scene class, and unknown formats fail.
    """
    if not parse_valid:
        return False, []
    if fmt == "marimo":
        passed, _, violations = check_marimo(code)
        return passed, list(violations or [])
    if fmt == "manim":
        return extract_scene_class(code) is not None, []
    return False, []


def _infer_static_check(code: str, fmt: str, parse_valid: bool) -> bool:
    """Return only the pass/fail result of ``_infer_static_check_details``."""
    return _infer_static_check_details(code, fmt, parse_valid)[0]


def _static_fail_multiplier(error_codes: list[str]) -> float:
    """Keep parseable but structurally invalid artifacts from scoring high.

    Any breaking ``MB*`` code yields ``GATE_STATIC_FAIL``; otherwise a slightly
    softer multiplier (1.5× that, capped at ``GATE_RUNS_FAIL``) is used.
    """
    if any(isinstance(err, str) and err.startswith("MB") for err in error_codes):
        return GATE_STATIC_FAIL
    return min(GATE_RUNS_FAIL, GATE_STATIC_FAIL * 1.5)


def _validity_score(
    parse_valid: float,
    static_check_passed: float,
    code_runs: float,
) -> float:
    """Map the three 0/1 validity flags onto a graded score (0, .35, .70, 1)."""
    if parse_valid == 0.0:
        return 0.0
    if static_check_passed == 0.0:
        return 0.35
    if code_runs == 0.0:
        return 0.70
    return 1.0


def adjust_repair_reward(
    base_reward: float,
    *,
    repair_success: bool,
    previous_error_codes: list[str],
    new_error_codes: list[str],
    previous_code: str,
    repaired_code: str,
) -> tuple[float, dict]:
    """Discount repaired code but reward fixing the specific prior failure.

    A successful repair keeps 60% of ``base_reward`` (+0.08 if none of the
    previous error codes recur, +0.04 if the code changed); a failed repair
    keeps 25% (+0.04 for clearing the prior codes, -0.15 if the code is
    unchanged). "Changed" ignores whitespace. The result is passed through
    ``clamp_action_reward`` and capped at ``MAX_REPAIR_REWARD``.

    Returns:
        ``(reward, components)`` with 0/1 flags and the rounded total.
    """
    changed = _fingerprint(previous_code) != _fingerprint(repaired_code)
    fixed_prior = bool(previous_error_codes) and not (
        set(previous_error_codes) & set(new_error_codes)
    )

    if repair_success:
        reward = base_reward * 0.60
        reward += 0.08 if fixed_prior else 0.0
        reward += 0.04 if changed else 0.0
    else:
        reward = base_reward * 0.25
        reward += 0.04 if fixed_prior else 0.0
        if not changed:
            reward -= 0.15

    reward = min(MAX_REPAIR_REWARD, clamp_action_reward(reward))
    return reward, {
        "repair_success": 1.0 if repair_success else 0.0,
        "fixed_prior_errors": 1.0 if fixed_prior else 0.0,
        "changed_code": 1.0 if changed else 0.0,
        "repair_total": round(reward, 4),
    }


def _tokens(text: str) -> list[str]:
    """Lowercased word tokens longer than 3 chars, excluding ``_STOPWORDS``."""
    return [
        w for w in _WORD_RE.findall(text.lower())
        if len(w) > 3 and w not in _STOPWORDS
    ]


def _fingerprint(code: str) -> str:
    """Return ``code`` with all whitespace removed, for change detection."""
    return re.sub(r"\s+", "", code)