"""Tests for reward components — exploration and generation."""

import math
import sys
from pathlib import Path

# Make the directory one level above this file's directory importable so the
# project-local modules below resolve when the file is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from constants import MAX_EXPLORE_REWARD, MAX_REPAIR_REWARD, normalized_episode_score  # noqa: E402
from rewards.exploration import (  # noqa: E402
    action_novelty,
    coverage_delta,
    compute_explore_reward,
    query_relevance,
    research_breadth,
    result_novelty,
    source_quality,
    tool_choice_score,
)
from rewards.generation import (  # noqa: E402
    adjust_repair_reward,
    compute_generate_reward,
    context_usage,
    format_match,
    keyword_coverage,
    marimo_structure,
    narration_score,
)
from rewards.sandbox import ast_parses, validate_code  # noqa: E402
from research.types import ResearchChunk, ResearchResult  # noqa: E402
from task_bank import ALL_TASKS  # noqa: E402

# Reference tasks shared by several tests; selected by topic from the task bank.
MARIMO_TASK = next(t for t in ALL_TASKS if t.topic == "Linear Regression")
MANIM_TASK = next(t for t in ALL_TASKS if t.topic == "Fourier Transform")


# --- Sandbox ---


def test_ast_parses():
    """ast_parses accepts a valid statement and rejects non-Python text."""
    assert ast_parses("x = 1") is True
    assert ast_parses("not python!!!") is False


def test_syntax_errors_are_verbose():
    """Rendered syntax errors carry the code, a line reference and a caret."""
    result = validate_code("marimo", "x = (1 +\n")
    rendered = result.render_errors()
    assert "PY_SYNTAX" in rendered
    assert "line" in rendered
    assert "^" in rendered


def test_marimo_duplicate_definitions_fail_static_check():
    """Two cells defining the same name parse but fail the check with MB002."""
    code = """import marimo
app = marimo.App()

@app.cell
def __():
    x = 1
    return x,

@app.cell
def __():
    x = 2
    return x,
"""
    result = validate_code("marimo", code)
    assert result.parses is True
    assert result.check_passed is False
    assert "MB002" in result.error_codes


def test_marimo_runtime_rejects_numpy_math_namespace():
    """Using np.math passes static checks but fails execution (MARIMO_EXPORT)."""
    code = """import marimo
app = marimo.App()

@app.cell
def __():
    import numpy as np
    value = np.math.factorial(3)
    return value,
"""
    result = validate_code("marimo", code)
    assert result.parses is True
    assert result.check_passed is True
    assert result.exec_success is False
    assert "MARIMO_EXPORT" in result.error_codes
    assert "np.math" in result.message or "module 'numpy'" in result.message


# --- Exploration rewards ---


def test_query_relevance():
    """On-topic queries score high; empty scores zero; off-topic scores low."""
    assert query_relevance("linear regression MSE", "Linear Regression", "linear regression,MSE") > 0.5
    assert query_relevance("", "Linear Regression", "x") == 0.0
    assert query_relevance("cats", "Linear Regression", "linear regression") < 0.3


def test_result_novelty():
    """New text is fully novel, repeated text is not, empty text scores zero."""
    assert result_novelty("new information here", []) == 1.0
    assert result_novelty("same words again", ["same words again"]) < 0.5
    assert result_novelty("", []) == 0.0


def test_action_novelty_penalizes_repeated_intent():
    """Repeating a prior tool/query/intent scores low; a different one high."""
    previous = ["search_wikipedia backpropagation algorithm neural network fundamentals"]
    assert action_novelty(
        "search_wikipedia",
        "backpropagation algorithm neural network",
        "fundamentals",
        previous,
    ) < 0.3
    assert action_novelty(
        "fetch_docs",
        "marimo slider plotting examples",
        "interactive code patterns",
        previous,
    ) > 0.7


def test_research_breadth():
    """Breadth is 0.0, 0.5 and 1.0 for zero, one and two sources of two."""
    assert research_breadth([], min_sources=2) == 0.0
    assert research_breadth(["a"], min_sources=2) == 0.5
    assert research_breadth(["a", "b"], min_sources=2) == 1.0


def test_tool_choice_score():
    """Both tool/difficulty/intent combinations below score the full 1.0."""
    assert tool_choice_score("search_arxiv", "hard", "recent research paper") == 1.0
    assert tool_choice_score("fetch_docs", "easy", "marimo plotting api") == 1.0


def test_source_quality():
    """A single long, dated arXiv chunk is rated above 0.7."""
    result = ResearchResult(
        tool="search_arxiv",
        query="linear regression",
        chunks=[
            ResearchChunk(
                source="arxiv",
                tool="search_arxiv",
                title="A paper",
                url="https://arxiv.org/abs/1",
                text="linear regression least squares optimization " * 10,
                score=1.0,
                metadata={"year": 2024},
            )
        ],
    )
    assert source_quality(result) > 0.7


def test_coverage_delta():
    """Text mentioning a task keyword yields a positive coverage delta."""
    assert coverage_delta(
        "linear regression,MSE",
        "linear regression",
        [],
        "mean squared error MSE",
    ) > 0.0


def test_explore_reward_integration():
    """A relevant search earns a bounded reward with the expected components."""
    result = ResearchResult(
        tool="search_wikipedia",
        query="linear regression least squares",
        chunks=[
            ResearchChunk(
                source="wikipedia",
                tool="search_wikipedia",
                title="Linear regression",
                url="https://example.test",
                text="Linear regression minimizes squared error with least squares.",
                score=1.0,
                metadata={"page": "Linear regression"},
            )
        ],
    )
    reward, comp = compute_explore_reward(
        query="linear regression least squares",
        tool="search_wikipedia",
        intent="beginner explanation",
        result=result,
        topic="Linear Regression",
        keywords_csv="linear regression,least squares,MSE",
        task_content="Linear regression is a method for modeling the relationship between variables.",
        difficulty="easy",
        previous_context=[],
        accumulated_context=["first search result"],
        used_tools=set(),
    )
    assert reward > 0.1
    assert reward <= MAX_EXPLORE_REWARD
    assert set(comp) == {
        "query_quality",
        "evidence_quality",
        "information_gain",
        "efficiency",
        "explore_total",
    }


def test_explore_reward_empty_result_is_gated():
    """A result with no chunks earns almost nothing and zero evidence/gain."""
    result = ResearchResult(
        tool="search_wikipedia",
        query="linear regression least squares MSE",
        chunks=[],
    )
    reward, comp = compute_explore_reward(
        query="linear regression least squares MSE",
        tool="search_wikipedia",
        intent="beginner explanation",
        result=result,
        topic="Linear Regression",
        keywords_csv="linear regression,least squares,MSE",
        task_content="",
        difficulty="easy",
        previous_context=[],
        accumulated_context=[],
        used_tools=set(),
    )
    assert reward < 0.05
    assert comp["evidence_quality"] == 0.0
    assert comp["information_gain"] == 0.0


# --- Generation rewards ---


def test_keyword_coverage():
    """Coverage exceeds 0.5 when most keywords appear and is 0.0 when none do."""
    assert keyword_coverage("linear regression MSE", "linear regression,MSE,gradient descent") > 0.5
    assert keyword_coverage("nothing", "linear regression,MSE") == 0.0


def test_format_match():
    """Matching format scores 1.0, mismatching 0.3, and no preference 1.0."""
    assert format_match("marimo", MARIMO_TASK) == 1.0
    assert format_match("manim", MARIMO_TASK) == 0.3
    # Task with preferred_format=None should score 1.0 for any format
    no_pref_task = next(t for t in ALL_TASKS if t.preferred_format is None)
    assert format_match("marimo", no_pref_task) == 1.0
    assert format_match("manim", no_pref_task) == 1.0


def test_narration_marimo():
    """An empty narration scores 1.0 for the marimo format."""
    assert narration_score("", "marimo") == 1.0


def test_narration_manim():
    """For manim, empty narration scores 0.0 and a multi-sentence one > 0.5."""
    assert narration_score("", "manim") == 0.0
    long_narration = (
        "First we introduce the concept. Next we show the graph. "
        "Then we animate the transformation step by step. "
        "Finally we summarize the key takeaways from this scene."
    )
    assert narration_score(long_narration, "manim") > 0.5


def test_structure_marimo():
    """A notebook with markdown, plotting import and a slider scores > 0.5."""
    good = """import marimo
app = marimo.App()

@app.cell
def __():
    import marimo as mo
    return mo,

@app.cell
def __(mo):
    mo.md("# Regression")
    return ()

@app.cell
def __():
    import matplotlib.pyplot as plt
    return plt,

@app.cell
def __(mo):
    slider = mo.ui.slider(0, 5)
    return slider,
"""
    assert marimo_structure(good, MARIMO_TASK) > 0.5


def test_marimo_structure_prefers_reactive_plot_wrappers():
    """A plot wrapped via mo.ui.matplotlib outscores a bare figure expression."""
    raw = """import marimo
app = marimo.App()

@app.cell
def __():
    import numpy as np
    import matplotlib.pyplot as plt
    return np, plt

@app.cell
def __(np, plt):
    _x = np.linspace(0, 1, 10)
    _fig, _ax = plt.subplots()
    _ax.plot(_x, _x)
    _fig
    return ()
"""
    reactive = """import marimo
app = marimo.App()

@app.cell
def __():
    import marimo as mo
    import numpy as np
    import matplotlib.pyplot as plt
    return mo, np, plt

@app.cell
def __(mo, np, plt):
    _x = np.linspace(0, 1, 10)
    _fig, _ax = plt.subplots()
    _ax.plot(_x, _x)
    mo.ui.matplotlib(plt.gca())
    return ()
"""
    raw_score = marimo_structure(raw, MARIMO_TASK, static_check_passed=True)
    reactive_score = marimo_structure(reactive, MARIMO_TASK, static_check_passed=True)
    assert reactive_score > raw_score


def test_context_usage():
    """No context gives 0.0; text overlapping the context scores above 0.3."""
    assert context_usage("x = 1", []) == 0.0
    assert context_usage(
        "linear regression least squares gradient descent optimization",
        ["linear regression least squares optimization methods"],
    ) > 0.3


def test_generate_reward_garbage():
    """Unparseable code gets a low reward and a validity component of 0.0."""
    reward, comp = compute_generate_reward(
        code="not python!!!",
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=False,
        accumulated_context=[],
    )
    assert reward < 0.4
    assert comp["validity"] == 0.0


def test_generate_reward_good():
    """A well-formed, on-topic notebook scores highly on every component."""
    code = """import marimo
app = marimo.App()

@app.cell
def __():
    import marimo as mo
    import numpy as np
    import matplotlib.pyplot as plt
    return mo, np, plt

@app.cell
def __(mo):
    mo.md("# Linear Regression")
    return ()

@app.cell
def __(np):
    # linear regression least squares MSE gradient descent weights bias
    X = np.linspace(0, 10, 50)
    y = 2 * X + 1
    return X, y

@app.cell
def __(mo):
    slider = mo.ui.slider(0, 5, value=2, label="Slope")
    return slider,
"""
    reward, comp = compute_generate_reward(
        code=code,
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=True,
        accumulated_context=["linear regression least squares"],
        static_check_passed=True,
    )
    assert reward > 0.6
    assert comp["validity"] == 1.0
    assert comp["task_alignment"] == 1.0
    assert comp["structure"] > 0.8
    assert comp["research_usage"] > 0.5


def test_marimo_static_failure_is_not_code_valid():
    """A static-check failure gives partial validity and a small reward."""
    code = """import marimo
app = marimo.App()

@app.cell
def __():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    return fig, ax

@app.cell
def __():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    return fig, ax
"""
    reward, comp = compute_generate_reward(
        code=code,
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=False,
        accumulated_context=["linear regression least squares"],
        static_check_passed=False,
        error_codes=["MB002"],
    )
    assert 0.0 < comp["validity"] < 1.0
    assert reward < 0.15


def test_generate_reward_wrong_format():
    """The same code earns more under the task's format than under another."""
    code = "import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n return\n"
    r_right, _ = compute_generate_reward(code, "marimo", "", MARIMO_TASK, False, [])
    r_wrong, _ = compute_generate_reward(code, "manim", "", MARIMO_TASK, False, [])
    assert r_right > r_wrong


def test_reward_spread():
    """Different code qualities across tasks yield at least three rewards."""
    rewards = []
    for task in ALL_TASKS[:5]:
        for code in ["bad!!!", "x = 1", "import marimo as mo\napp = mo.App()"]:
            r, _ = compute_generate_reward(code, "marimo", "", task, False, [])
            rewards.append(r)
    unique = {round(r, 3) for r in rewards}
    assert len(unique) >= 3


def test_repair_reward_success_is_discounted_and_changed():
    """A successful repair of a 1.0 reward is discounted to 0.72."""
    reward, comp = adjust_repair_reward(
        1.0,
        repair_success=True,
        previous_error_codes=["PY_SYNTAX"],
        new_error_codes=[],
        previous_code="x =",
        repaired_code="x = 1",
    )
    # Compare with a tolerance: the reward is a float, so exact equality
    # could fail on representation error alone.
    assert math.isclose(reward, 0.72, rel_tol=1e-9, abs_tol=1e-9)
    assert 0.0 <= reward <= MAX_REPAIR_REWARD
    assert comp["repair_success"] == 1.0
    assert comp["fixed_prior_errors"] == 1.0
    assert comp["changed_code"] == 1.0


def test_repair_reward_penalizes_repeated_code():
    """Resubmitting identical code earns less than submitting changed code."""
    changed_reward, _ = adjust_repair_reward(
        1.0,
        repair_success=True,
        previous_error_codes=["PY_SYNTAX"],
        new_error_codes=[],
        previous_code="x =",
        repaired_code="x = 1",
    )
    repeated_reward, comp = adjust_repair_reward(
        1.0,
        repair_success=True,
        previous_error_codes=["PY_SYNTAX"],
        new_error_codes=[],
        previous_code="x =",
        repaired_code="x =",
    )
    assert repeated_reward < changed_reward
    assert comp["changed_code"] == 0.0


def test_repair_reward_failed_fix_stays_discounted():
    """A failed repair keeps a positive reward strictly below the repair cap."""
    reward, comp = adjust_repair_reward(
        0.8,
        repair_success=False,
        previous_error_codes=["MB002"],
        new_error_codes=["MB002"],
        previous_code="x = 1",
        repaired_code="x = 2",
    )
    assert 0.0 < reward < MAX_REPAIR_REWARD
    assert comp["repair_success"] == 0.0
    assert comp["fixed_prior_errors"] == 0.0


def test_normalized_episode_score_bounds():
    """Out-of-range episode scores are clamped to 0.0 and 1.0."""
    assert normalized_episode_score(-1.0) == 0.0
    assert normalized_episode_score(999.0) == 1.0


if __name__ == "__main__":
    tests = [
        test_ast_parses,
        test_syntax_errors_are_verbose,
        test_marimo_duplicate_definitions_fail_static_check,
        test_marimo_runtime_rejects_numpy_math_namespace,
        test_query_relevance,
        test_result_novelty,
        test_action_novelty_penalizes_repeated_intent,
        test_research_breadth,
        test_tool_choice_score,
        test_source_quality,
        test_coverage_delta,
        test_explore_reward_integration,
        test_explore_reward_empty_result_is_gated,
        test_keyword_coverage,
        test_format_match,
        test_narration_marimo,
        test_narration_manim,
        test_structure_marimo,
        test_marimo_structure_prefers_reactive_plot_wrappers,
        test_context_usage,
        test_generate_reward_garbage,
        test_generate_reward_good,
        test_marimo_static_failure_is_not_code_valid,
        test_generate_reward_wrong_format,
        test_reward_spread,
        test_repair_reward_success_is_discounted_and_changed,
        test_repair_reward_penalizes_repeated_code,
        test_repair_reward_failed_fix_stays_discounted,
        test_normalized_episode_score_bounds,
    ]
    passed = 0
    for t in tests:
        # Broad catch is deliberate: this is the top-level runner boundary and
        # one failing test must not stop the remaining tests from running.
        try:
            t()
        except Exception as e:
            # Include the exception type: a bare ``assert`` has an empty message.
            print(f"FAIL: {t.__name__}: {type(e).__name__}: {e}")
        else:
            passed += 1
    failed = len(tests) - passed
    # Only report PASS when every test passed, and signal failure to the
    # caller through the exit status.
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_rewards ({passed}/{len(tests)})")
    if failed:
        sys.exit(1)