Spaces:
Sleeping
Sleeping
| """Reward components for the generation phase. | |
| After exploration, the agent generates marimo/manim code. Rewards measure | |
| validity, task alignment, artifact structure, and research usage. | |
| Scoring model: | |
| quality = weighted sum of (validity, task alignment, structure, research usage) | |
| total = quality × gate | |
| Gates (multiplicative): | |
| - code doesn't parse → total = 0 | |
| - static check fails → total = quality × small static-fail multiplier | |
| - code doesn't run → total = quality × execution-fail multiplier | |
| - code runs → total = quality × 1.0 | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from typing import TYPE_CHECKING | |
| from .sandbox import ast_parses, check_marimo, extract_scene_class | |
| try: | |
| from ..constants import MAX_REPAIR_REWARD, clamp_action_reward | |
| except ImportError: # pragma: no cover - supports direct test execution | |
| from constants import MAX_REPAIR_REWARD, clamp_action_reward | |
| if TYPE_CHECKING: | |
| from ..task_bank import Task | |
| # --------------------------------------------------------------------------- | |
| # Component weights | |
| # --------------------------------------------------------------------------- | |
| _WEIGHTS = { | |
| "validity": 0.15, | |
| "task_alignment": 0.30, | |
| "structure": 0.30, | |
| "research_usage": 0.25, | |
| } | |
| GATE_STATIC_FAIL = 0.12 | |
| GATE_RUNS_FAIL = 0.30 # quality multiplier when static checks pass but execution fails | |
| _STOPWORDS = { | |
| "about", "after", "again", "against", "also", "because", "before", "being", | |
| "between", "class", "code", "construct", "could", "from", "have", "into", | |
| "like", "make", "more", "most", "only", "self", "show", "step", "than", | |
| "that", "their", "then", "there", "these", "this", "through", "using", | |
| "value", "where", "with", "would", | |
| } | |
| # --------------------------------------------------------------------------- | |
| # Individual scorers | |
| # --------------------------------------------------------------------------- | |
def keyword_coverage(code: str, keywords_csv: str) -> float:
    """Return the share of task keywords that occur in ``code``.

    Args:
        code: Generated source text to search.
        keywords_csv: Comma-separated keywords; entries are stripped and
            lower-cased, and blank entries are ignored.

    Returns:
        ``hits / number_of_keywords`` in ``[0.0, 1.0]``, where a hit is a
        case-insensitive substring match. ``0.0`` when ``keywords_csv`` is
        empty or contains only blank entries.
    """
    if not keywords_csv:
        return 0.0

    # Normalise first, then discard entries that were only whitespace.
    normalised = (piece.strip().lower() for piece in keywords_csv.split(","))
    wanted = [term for term in normalised if term]
    if not wanted:
        return 0.0

    haystack = code.lower()
    hits = 0
    for term in wanted:
        if term in haystack:
            hits += 1
    return hits / len(wanted)
def format_match(chosen_format: str, task: Task) -> float:
    """Score how well the chosen output format fits the task.

    Args:
        chosen_format: Format the agent generated.
        task: Task whose ``preferred_format`` is compared against.

    Returns:
        ``1.0`` when the task states no preference (``None``) or the choice
        equals the preference; ``0.3`` for a mismatch.
    """
    preferred = task.preferred_format
    if preferred is None:
        # No stated preference: every format is acceptable.
        return 1.0
    if chosen_format == preferred:
        return 1.0
    return 0.3
def marimo_structure(
    code: str,
    task: Task,
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> float:
    """Rate the structure of a marimo notebook on a 0-1 scale.

    Credit is added for recognisable notebook scaffolding, UI elements,
    plotting and enough cells for the task tier; it is removed for a few
    discouraged snippets and for violations reported by the static check.

    Args:
        code: Notebook source text; all signals are substring tests on it.
        task: Task providing ``data_available`` and ``tier``.
        static_check_passed: Result of an earlier static check. When ``None``
            the check is run here through ``check_marimo``.
        error_codes: Violation codes that accompany ``static_check_passed``;
            ignored when the check is run here.

    Returns:
        The accumulated score clamped to ``[0.0, 1.0]``.
    """

    def contains_any(needles) -> bool:
        return any(needle in code for needle in needles)

    def by_data(with_data: float, without_data: float) -> float:
        # Plot credit is larger when the task ships data to plot.
        return with_data if task.data_available else without_data

    total = 0.0

    # --- Scaffolding: imports, app object, cells -------------------------
    if contains_any(("import marimo", "from marimo")):
        total += 0.2
    if contains_any(("marimo.App", "mo.App")):
        total += 0.1

    n_cells = code.count("@app.cell")
    if n_cells >= 3:
        total += 0.2
    elif n_cells >= 1:
        total += 0.1

    # --- UI elements: 0.06 each, capped at 0.22 --------------------------
    ui_elements = (
        "mo.md(",
        "mo.Html",
        "mo.accordion",
        "mo.callout",
        "mo.hstack(",
        "mo.vstack(",
        "mo.ui.slider",
        "mo.ui.dropdown",
        "mo.ui.table",
        "mo.ui.dataframe",
    )
    total += min(0.22, sum(0.06 for element in ui_elements if element in code))

    # --- Plotting: best credit for the wrapped matplotlib axes, less for
    # other reactive wrappers, and a net non-positive result for raw plots.
    reactive_plots = (
        "mo.ui.matplotlib(",
        "mo.ui.plotly(",
        "mo.ui.altair_chart(",
    )
    raw_plots = (
        "plt.",
        "matplotlib.pyplot",
        "px.",
        "plotly.",
        "alt.Chart",
    )
    if "mo.ui.matplotlib(plt.gca())" in code:
        total += by_data(0.24, 0.16)
    elif contains_any(reactive_plots):
        total += by_data(0.18, 0.10)
    elif contains_any(raw_plots):
        total += by_data(0.08, 0.03)
        total -= 0.08

    # --- Discouraged snippets --------------------------------------------
    for snippet, cost in (("plt.tight_layout(", 0.12), ("np.math.", 0.15)):
        if snippet in code:
            total -= cost

    # --- Enough cells for the task tier (unknown tiers need 2) -----------
    cells_required = {"advanced": 6, "intermediate": 4, "beginner": 2}
    if n_cells >= cells_required.get(task.tier, 2):
        total += 0.1

    # --- Static check: bonus when clean, per-violation penalty otherwise -
    if static_check_passed is None:
        clean, _, found = check_marimo(code)
    else:
        clean = static_check_passed
        found = error_codes or []

    if clean:
        total += 0.1
    else:
        cost_by_code = {
            "MB002": 0.35,
            "MB003": 0.4,
            "MB005": 0.25,
            "MB001": 0.3,
            "MB004": 0.2,
        }
        for violation in found:
            # Codes missing from the table cost a flat 0.15.
            total -= cost_by_code.get(violation, 0.15)

    return max(0.0, min(1.0, total))
def manim_structure(code: str, task: Task) -> float:
    """Score structural quality of a manim scene (0-1).

    All signals except the scene-class lookup are plain substring tests on
    ``code``.

    Args:
        code: Scene source text.
        task: Task providing ``tier``, which sets how many distinct
            animation patterns earn the tier bonus.

    Returns:
        The accumulated score, capped at ``1.0``.
    """
    # ``extract_scene_class`` is already imported at module level, so no
    # function-scope re-import is needed here.
    score = 0.0

    if "from manim" in code or "import manim" in code:
        score += 0.2
    if extract_scene_class(code) is not None:
        score += 0.2
    if "def construct" in code:
        score += 0.1

    # Distinct animation patterns present, 0.05 each up to 0.3.
    # NOTE: "Transform(" also matches inside "ReplacementTransform(", so a
    # scene that uses only the latter registers two hits.
    anim_patterns = [
        "self.play(", "self.wait(", "Create(", "FadeIn(", "FadeOut(",
        "Transform(", "Write(", "MoveToTarget", "Indicate(",
        "ReplacementTransform(",
    ]
    anim_hits = sum(1 for p in anim_patterns if p in code)
    score += min(0.3, anim_hits * 0.05)

    # Any mathematical object earns a single flat bonus.
    math_patterns = ["MathTex(", "Tex(", "Axes(", "NumberPlane(", "Graph("]
    if any(p in code for p in math_patterns):
        score += 0.1

    # Tier bonus: unknown tiers fall back to the beginner threshold of 2.
    tier_thresholds = {"advanced": 6, "intermediate": 4, "beginner": 2}
    if anim_hits >= tier_thresholds.get(task.tier, 2):
        score += 0.1

    return min(1.0, score)
def narration_score(narration: str, fmt: str) -> float:
    """Rate narration text; only the ``"manim"`` format is actually scored.

    Args:
        narration: Narration text accompanying the artifact.
        fmt: Chosen output format.

    Returns:
        ``1.0`` for any format other than ``"manim"``; ``0.0`` for missing or
        whitespace-only narration; otherwise the sum of two word-count
        credits and a cue-word credit, capped at ``1.0``.
    """
    if fmt != "manim":
        return 1.0
    if not narration or not narration.strip():
        return 0.0

    n_words = len(narration.split())
    lowered = narration.lower()

    def tier_credit(steps) -> float:
        # Credit of the first (highest) word-count floor that is met.
        return next((credit for floor, credit in steps if n_words >= floor), 0.0)

    total = 0.0
    total += tier_credit(((30, 0.4), (10, 0.2)))

    # Sequencing cues, 0.1 per distinct cue found (substring match), max 0.3.
    cues = ("scene", "step", "first", "next", "then", "finally", "now")
    total += min(0.3, sum(0.1 for cue in cues if cue in lowered))

    total += tier_credit(((50, 0.3), (20, 0.15)))
    return min(1.0, total)
def context_usage(code: str, accumulated_context: list[str]) -> float:
    """Measure how much research vocabulary shows up in the code (0-1).

    Both sides are reduced to token sets with ``_tokens`` before comparing.

    Args:
        code: Generated source text.
        accumulated_context: Research snippets gathered before generation.

    Returns:
        ``0.0`` when there is no context, no usable context tokens, or no
        shared token; otherwise ``shared / target * 2.5`` capped at ``1.0``,
        where ``target`` is the context vocabulary size limited to 24.
    """
    if not accumulated_context:
        return 0.0

    vocabulary = {
        token for snippet in accumulated_context for token in _tokens(snippet)
    }
    if not vocabulary:
        return 0.0

    shared = vocabulary.intersection(_tokens(code))
    if not shared:
        return 0.0

    # Capping the denominator at 24 keeps large contexts attainable, while
    # the 2.5 factor means a full score still needs 40% of the target.
    target = min(len(vocabulary), 24)
    return min(1.0, len(shared) / target * 2.5)
| # --------------------------------------------------------------------------- | |
| # Main reward function | |
| # --------------------------------------------------------------------------- | |
def compute_generate_reward(
    code: str,
    fmt: str,
    narration: str,
    task: Task,
    exec_success: bool,
    accumulated_context: list[str],
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> tuple[float, dict]:
    """Compute the generation-phase reward. Returns (total, components).

    ``quality`` is the ``_WEIGHTS``-weighted sum of validity, task alignment,
    structure and research usage. It is then gated multiplicatively:

    - code does not parse           -> ``total = 0``
    - static check fails            -> ``quality * _static_fail_multiplier``
    - static check passes, run fails -> ``quality * GATE_RUNS_FAIL``
    - code runs                     -> ``quality``

    Args:
        code: Generated source text.
        fmt: Chosen format; ``"marimo"`` uses the notebook scorer, anything
            else is scored as a manim scene plus narration.
        narration: Narration text (scored only for the manim format).
        task: Task supplying keywords, preferred format, tier, etc.
        exec_success: Whether the artifact executed successfully.
        accumulated_context: Research snippets gathered before generation.
        static_check_passed: Result of an earlier static check, or ``None``
            to have it inferred here.
        error_codes: Violation codes from that earlier check, if any.

    Returns:
        ``(total, components)`` where ``total`` is the unrounded gated reward
        and ``components`` holds the rounded ``validity``,
        ``task_alignment``, ``structure``, ``research_usage`` and
        ``generate_total`` values.
    """
    parse_valid = ast_parses(code)
    c_parse = 1.0 if parse_valid else 0.0

    if static_check_passed is None:
        if parse_valid and fmt == "marimo":
            # Run the marimo check once and keep its violation codes, so the
            # per-violation structure penalties and the stricter ``MB*`` gate
            # apply exactly as they do when the caller supplies the result.
            static_check_passed, _, found_codes = check_marimo(code)
            if error_codes is None:
                error_codes = list(found_codes or [])
        else:
            static_check_passed = _infer_static_check(code, fmt, parse_valid)

    c_static = 1.0 if parse_valid and static_check_passed else 0.0
    c_runs = 1.0 if exec_success else 0.0

    c_coverage = keyword_coverage(code, task.keywords)
    c_format = format_match(fmt, task)

    if fmt == "marimo":
        c_struct = marimo_structure(code, task, static_check_passed, error_codes)
    else:
        scene_structure = manim_structure(code, task)
        c_struct = 0.75 * scene_structure + 0.25 * narration_score(narration, fmt)

    c_ctx = context_usage(code, accumulated_context)
    c_validity = _validity_score(c_parse, c_static, c_runs)
    c_alignment = 0.75 * c_coverage + 0.25 * c_format

    quality = (
        _WEIGHTS["validity"] * c_validity
        + _WEIGHTS["task_alignment"] * c_alignment
        + _WEIGHTS["structure"] * c_struct
        + _WEIGHTS["research_usage"] * c_ctx
    )

    # Apply gates, most severe failure first.
    if c_parse == 0.0:
        total = 0.0
    elif c_static == 0.0:
        total = quality * _static_fail_multiplier(error_codes or [])
    elif c_runs == 0.0:
        total = quality * GATE_RUNS_FAIL
    else:
        total = quality

    components = {
        "validity": round(c_validity, 3),
        "task_alignment": round(c_alignment, 3),
        "structure": round(c_struct, 3),
        "research_usage": round(c_ctx, 3),
        "generate_total": round(total, 4),
    }
    return total, components
| def _infer_static_check(code: str, fmt: str, parse_valid: bool) -> bool: | |
| if not parse_valid: | |
| return False | |
| if fmt == "marimo": | |
| passed, _, _ = check_marimo(code) | |
| return passed | |
| if fmt == "manim": | |
| return extract_scene_class(code) is not None | |
| return False | |
def _static_fail_multiplier(error_codes: list[str]) -> float:
    """Pick the quality multiplier for parseable but statically invalid code.

    Args:
        error_codes: Violation codes reported by the static check.

    Returns:
        ``GATE_STATIC_FAIL`` when any code starts with ``"MB"``; otherwise
        1.5 times that value, never exceeding ``GATE_RUNS_FAIL``.
    """
    for error_code in error_codes:
        if error_code.startswith("MB"):
            return GATE_STATIC_FAIL
    # No MB-prefixed violation: softer than the MB gate, but never more
    # lenient than the execution-failure gate.
    return min(GATE_RUNS_FAIL, GATE_STATIC_FAIL * 1.5)
| def _validity_score( | |
| parse_valid: float, | |
| static_check_passed: float, | |
| code_runs: float, | |
| ) -> float: | |
| if parse_valid == 0.0: | |
| return 0.0 | |
| if static_check_passed == 0.0: | |
| return 0.35 | |
| if code_runs == 0.0: | |
| return 0.70 | |
| return 1.0 | |
def adjust_repair_reward(
    base_reward: float,
    *,
    repair_success: bool,
    previous_error_codes: list[str],
    new_error_codes: list[str],
    previous_code: str,
    repaired_code: str,
) -> tuple[float, dict]:
    """Scale a reward for a repair attempt and report what the repair did.

    Args:
        base_reward: Reward computed for the repaired artifact.
        repair_success: Whether the repair attempt succeeded.
        previous_error_codes: Error codes from the failed attempt.
        new_error_codes: Error codes after the repair.
        previous_code: Source before the repair.
        repaired_code: Source after the repair.

    Returns:
        ``(reward, details)``. ``reward`` is passed through
        ``clamp_action_reward`` and capped at ``MAX_REPAIR_REWARD``;
        ``details`` holds ``repair_success``, ``fixed_prior_errors`` and
        ``changed_code`` as 0/1 floats plus the rounded ``repair_total``.
    """
    # Whitespace-only edits do not count as a change.
    changed = _fingerprint(previous_code) != _fingerprint(repaired_code)

    # "Fixed" means there were prior errors and none of them recurs.
    if previous_error_codes:
        fixed_prior = set(previous_error_codes).isdisjoint(set(new_error_codes))
    else:
        fixed_prior = False

    if repair_success:
        scale, fix_bonus = 0.60, 0.08
    else:
        scale, fix_bonus = 0.25, 0.04

    reward = base_reward * scale
    reward += fix_bonus if fixed_prior else 0.0
    if repair_success:
        reward += 0.04 if changed else 0.0
    elif not changed:
        # Resubmitting effectively identical code after a failure is penalised.
        reward -= 0.15

    reward = min(MAX_REPAIR_REWARD, clamp_action_reward(reward))

    details = {
        "repair_success": 1.0 if repair_success else 0.0,
        "fixed_prior_errors": 1.0 if fixed_prior else 0.0,
        "changed_code": 1.0 if changed else 0.0,
        "repair_total": round(reward, 4),
    }
    return reward, details
| def _tokens(text: str) -> list[str]: | |
| return [ | |
| w | |
| for w in re.findall(r"\w+", text.lower()) | |
| if len(w) > 3 and w not in _STOPWORDS | |
| ] | |
| def _fingerprint(code: str) -> str: | |
| return re.sub(r"\s+", "", code) | |