Spaces:
Running
Running
| """Tests for reward components — exploration and generation.""" | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from constants import MAX_EXPLORE_REWARD, MAX_REPAIR_REWARD, normalized_episode_score | |
| from rewards.exploration import ( | |
| action_novelty, | |
| coverage_delta, | |
| compute_explore_reward, | |
| query_relevance, | |
| research_breadth, | |
| result_novelty, | |
| source_quality, | |
| tool_choice_score, | |
| ) | |
| from rewards.generation import ( | |
| adjust_repair_reward, | |
| compute_generate_reward, | |
| context_usage, | |
| format_match, | |
| keyword_coverage, | |
| marimo_structure, | |
| narration_score, | |
| ) | |
| from rewards.sandbox import ast_parses, validate_code | |
| from research.types import ResearchChunk, ResearchResult | |
| from task_bank import ALL_TASKS | |
def _first_task_with_topic(topic):
    """Return the first entry of ALL_TASKS whose ``topic`` equals *topic*.

    Raises StopIteration when no task carries that topic.
    """
    return next(filter(lambda task: task.topic == topic, ALL_TASKS))


# Shared task fixtures, looked up by topic once at import time.
MARIMO_TASK = _first_task_with_topic("Linear Regression")
MANIM_TASK = _first_task_with_topic("Fourier Transform")
| # --- Sandbox --- | |
def test_ast_parses():
    """Check ast_parses on one valid snippet and one non-Python string."""
    expectations = (
        ("x = 1", True),
        ("not python!!!", False),
    )
    for snippet, expected in expectations:
        assert ast_parses(snippet) is expected
def test_syntax_errors_are_verbose():
    """Rendered errors for an unclosed parenthesis must contain three markers.

    The rendered text is expected to include the ``PY_SYNTAX`` code, the word
    ``line``, and a ``^`` caret.
    """
    report = validate_code("marimo", "x = (1 +\n").render_errors()
    for fragment in ("PY_SYNTAX", "line", "^"):
        assert fragment in report
def test_marimo_duplicate_definitions_fail_static_check():
    """A notebook whose two cells both assign ``x`` parses but fails the static check.

    Expects ``validate_code`` to report ``parses`` True, ``check_passed`` False,
    and ``MB002`` among ``error_codes``.
    """
    # NOTE(review): cell-body indentation inside this embedded notebook source
    # was reconstructed (4 spaces) -- confirm against the original file.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    x = 1
    return x,
@app.cell
def __():
    x = 2
    return x,
"""
    result = validate_code("marimo", code)
    assert result.parses is True
    assert result.check_passed is False
    assert "MB002" in result.error_codes
def test_marimo_runtime_rejects_numpy_math_namespace():
    """A cell calling ``np.math.factorial`` passes static checks but fails at run time.

    Expects ``parses`` and ``check_passed`` to be True, ``exec_success`` to be
    False, ``MARIMO_EXPORT`` among ``error_codes``, and a message that mentions
    either ``np.math`` or ``module 'numpy'``.
    """
    # NOTE(review): cell-body indentation inside this embedded notebook source
    # was reconstructed (4 spaces) -- confirm against the original file.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    import numpy as np
    value = np.math.factorial(3)
    return value,
"""
    result = validate_code("marimo", code)
    assert result.parses is True
    assert result.check_passed is True
    assert result.exec_success is False
    assert "MARIMO_EXPORT" in result.error_codes
    # Either wording is accepted for the missing-attribute error.
    assert "np.math" in result.message or "module 'numpy'" in result.message
| # --- Exploration rewards --- | |
def test_query_relevance():
    """Check query_relevance for an on-topic, an empty, and an off-topic query."""
    topic = "Linear Regression"

    on_topic = query_relevance("linear regression MSE", topic, "linear regression,MSE")
    assert on_topic > 0.5

    empty_query = query_relevance("", topic, "x")
    assert empty_query == 0.0

    off_topic = query_relevance("cats", topic, "linear regression")
    assert off_topic < 0.3
def test_result_novelty():
    """Check result_novelty for fresh text, repeated text, and empty text."""
    fresh = result_novelty("new information here", [])
    assert fresh == 1.0

    repeated = result_novelty("same words again", ["same words again"])
    assert repeated < 0.5

    blank = result_novelty("", [])
    assert blank == 0.0
def test_action_novelty_penalizes_repeated_intent():
    """An action matching the history scores below 0.3; a different one above 0.7."""
    history = ["search_wikipedia backpropagation algorithm neural network fundamentals"]

    # Same tool, query, and intent as the single history entry.
    repeated_score = action_novelty(
        "search_wikipedia",
        "backpropagation algorithm neural network",
        "fundamentals",
        history,
    )
    assert repeated_score < 0.3

    # Different tool, query, and intent.
    fresh_score = action_novelty(
        "fetch_docs",
        "marimo slider plotting examples",
        "interactive code patterns",
        history,
    )
    assert fresh_score > 0.7
def test_research_breadth():
    """Check research_breadth for 0, 1, and 2 sources with ``min_sources=2``."""
    cases = (
        ([], 0.0),
        (["a"], 0.5),
        (["a", "b"], 1.0),
    )
    for sources, expected in cases:
        assert research_breadth(sources, min_sources=2) == expected
def test_tool_choice_score():
    """Both listed (tool, difficulty, intent) combinations must score exactly 1.0."""
    full_score_calls = (
        ("search_arxiv", "hard", "recent research paper"),
        ("fetch_docs", "easy", "marimo plotting api"),
    )
    for call_args in full_score_calls:
        assert tool_choice_score(*call_args) == 1.0
def test_source_quality():
    """A single arXiv chunk with repeated on-topic text must score above 0.7."""
    arxiv_chunk = ResearchChunk(
        source="arxiv",
        tool="search_arxiv",
        title="A paper",
        url="https://arxiv.org/abs/1",
        text="linear regression least squares optimization " * 10,
        score=1.0,
        metadata={"year": 2024},
    )
    result = ResearchResult(
        tool="search_arxiv",
        query="linear regression",
        chunks=[arxiv_chunk],
    )
    quality = source_quality(result)
    assert quality > 0.7
def test_coverage_delta():
    """coverage_delta must be positive for the arguments used here."""
    delta = coverage_delta(
        "linear regression,MSE",
        "linear regression",
        [],
        "mean squared error MSE",
    )
    assert delta > 0.0
def test_explore_reward_integration():
    """End-to-end check of compute_explore_reward on a one-chunk Wikipedia result.

    The reward must lie in (0.1, MAX_EXPLORE_REWARD] and the component mapping
    must expose exactly the five expected keys.
    """
    search_query = "linear regression least squares"
    wiki_chunk = ResearchChunk(
        source="wikipedia",
        tool="search_wikipedia",
        title="Linear regression",
        url="https://example.test",
        text="Linear regression minimizes squared error with least squares.",
        score=1.0,
        metadata={"page": "Linear regression"},
    )
    result = ResearchResult(
        tool="search_wikipedia",
        query=search_query,
        chunks=[wiki_chunk],
    )
    expected_components = {
        "query_quality",
        "evidence_quality",
        "information_gain",
        "efficiency",
        "explore_total",
    }

    reward, components = compute_explore_reward(
        query=search_query,
        tool="search_wikipedia",
        intent="beginner explanation",
        result=result,
        topic="Linear Regression",
        keywords_csv="linear regression,least squares,MSE",
        task_content="Linear regression is a method for modeling the relationship between variables.",
        difficulty="easy",
        previous_context=[],
        accumulated_context=["first search result"],
        used_tools=set(),
    )

    assert reward > 0.1
    assert reward <= MAX_EXPLORE_REWARD
    assert set(components) == expected_components
def test_explore_reward_empty_result_is_gated():
    """A result with no chunks must yield a reward below 0.05.

    The ``evidence_quality`` and ``information_gain`` components must both be
    exactly 0.0.
    """
    search_query = "linear regression least squares MSE"
    empty_result = ResearchResult(
        tool="search_wikipedia",
        query=search_query,
        chunks=[],
    )

    reward, components = compute_explore_reward(
        query=search_query,
        tool="search_wikipedia",
        intent="beginner explanation",
        result=empty_result,
        topic="Linear Regression",
        keywords_csv="linear regression,least squares,MSE",
        task_content="",
        difficulty="easy",
        previous_context=[],
        accumulated_context=[],
        used_tools=set(),
    )

    assert reward < 0.05
    for gated_component in ("evidence_quality", "information_gain"):
        assert components[gated_component] == 0.0
| # --- Generation rewards --- | |
def test_keyword_coverage():
    """Check keyword_coverage for a mostly-matching text and a non-matching text."""
    partial = keyword_coverage("linear regression MSE", "linear regression,MSE,gradient descent")
    assert partial > 0.5

    no_match = keyword_coverage("nothing", "linear regression,MSE")
    assert no_match == 0.0
def test_format_match():
    """Check format_match against MARIMO_TASK and against a task with no preference."""
    matching = format_match("marimo", MARIMO_TASK)
    assert matching == 1.0
    mismatching = format_match("manim", MARIMO_TASK)
    assert mismatching == 0.3

    # A task whose preferred_format is None should score 1.0 for every format.
    no_pref_task = next(task for task in ALL_TASKS if task.preferred_format is None)
    for fmt in ("marimo", "manim"):
        assert format_match(fmt, no_pref_task) == 1.0
def test_narration_marimo():
    """An empty narration must score 1.0 for the ``marimo`` format."""
    score = narration_score("", "marimo")
    assert score == 1.0
def test_narration_manim():
    """For ``manim``, empty narration scores 0.0 and a multi-sentence one above 0.5."""
    empty_score = narration_score("", "manim")
    assert empty_score == 0.0

    sentences = [
        "First we introduce the concept.",
        "Next we show the graph.",
        "Then we animate the transformation step by step.",
        "Finally we summarize the key takeaways from this scene.",
    ]
    long_narration = " ".join(sentences)
    assert narration_score(long_narration, "manim") > 0.5
def test_structure_marimo():
    """A four-cell notebook must get a marimo_structure score above 0.5.

    The cells are: a marimo import, a markdown heading, a matplotlib import,
    and a UI slider. Scored against MARIMO_TASK.
    """
    # NOTE(review): cell-body indentation inside this embedded notebook source
    # was reconstructed (4 spaces) -- confirm against the original file.
    good = """import marimo
app = marimo.App()
@app.cell
def __():
    import marimo as mo
    return mo,
@app.cell
def __(mo):
    mo.md("# Regression")
    return ()
@app.cell
def __():
    import matplotlib.pyplot as plt
    return plt,
@app.cell
def __(mo):
    slider = mo.ui.slider(0, 5)
    return slider,
"""
    assert marimo_structure(good, MARIMO_TASK) > 0.5
def test_marimo_structure_prefers_reactive_plot_wrappers():
    """A notebook wrapping its plot in ``mo.ui.matplotlib`` must outscore a bare figure.

    Both notebooks draw the same line plot. ``raw`` ends the cell with the bare
    ``_fig`` expression, while ``reactive`` calls ``mo.ui.matplotlib(plt.gca())``.
    Both are scored with ``static_check_passed=True``.
    """
    # NOTE(review): cell-body indentation inside both embedded notebook sources
    # was reconstructed (4 spaces) -- confirm against the original file.
    raw = """import marimo
app = marimo.App()
@app.cell
def __():
    import numpy as np
    import matplotlib.pyplot as plt
    return np, plt
@app.cell
def __(np, plt):
    _x = np.linspace(0, 1, 10)
    _fig, _ax = plt.subplots()
    _ax.plot(_x, _x)
    _fig
    return ()
"""
    reactive = """import marimo
app = marimo.App()
@app.cell
def __():
    import marimo as mo
    import numpy as np
    import matplotlib.pyplot as plt
    return mo, np, plt
@app.cell
def __(mo, np, plt):
    _x = np.linspace(0, 1, 10)
    _fig, _ax = plt.subplots()
    _ax.plot(_x, _x)
    mo.ui.matplotlib(plt.gca())
    return ()
"""
    raw_score = marimo_structure(raw, MARIMO_TASK, static_check_passed=True)
    reactive_score = marimo_structure(reactive, MARIMO_TASK, static_check_passed=True)
    assert reactive_score > raw_score
def test_context_usage():
    """context_usage is 0.0 with no context and above 0.3 when text overlaps it."""
    without_context = context_usage("x = 1", [])
    assert without_context == 0.0

    generated_text = "linear regression least squares gradient descent optimization"
    research_notes = ["linear regression least squares optimization methods"]
    assert context_usage(generated_text, research_notes) > 0.3
def test_generate_reward_garbage():
    """Non-Python input must score below 0.4 with a ``validity`` component of 0.0."""
    reward, components = compute_generate_reward(
        code="not python!!!",
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=False,
        accumulated_context=[],
    )
    assert reward < 0.4
    assert components["validity"] == 0.0
def test_generate_reward_good():
    """A well-formed notebook that ran successfully must earn a high reward.

    The notebook is passed with ``exec_success=True`` and
    ``static_check_passed=True``. Expects reward > 0.6, ``validity`` and
    ``task_alignment`` both 1.0, ``structure`` > 0.8, and
    ``research_usage`` > 0.5.
    """
    # NOTE(review): cell-body indentation inside this embedded notebook source
    # was reconstructed (4 spaces) -- confirm against the original file.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    import marimo as mo
    import numpy as np
    import matplotlib.pyplot as plt
    return mo, np, plt
@app.cell
def __(mo):
    mo.md("# Linear Regression")
    return ()
@app.cell
def __(np):
    # linear regression least squares MSE gradient descent weights bias
    X = np.linspace(0, 10, 50)
    y = 2 * X + 1
    return X, y
@app.cell
def __(mo):
    slider = mo.ui.slider(0, 5, value=2, label="Slope")
    return slider,
"""
    reward, comp = compute_generate_reward(
        code=code,
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=True,
        accumulated_context=["linear regression least squares"],
        static_check_passed=True,
    )
    assert reward > 0.6
    assert comp["validity"] == 1.0
    assert comp["task_alignment"] == 1.0
    assert comp["structure"] > 0.8
    assert comp["research_usage"] > 0.5
def test_marimo_static_failure_is_not_code_valid():
    """A notebook reported as failing the static check must get only partial validity.

    Both cells define ``plt``, ``fig`` and ``ax``. The reward is computed with
    ``static_check_passed=False`` and ``error_codes=["MB002"]``. Expects
    ``validity`` strictly between 0 and 1 and a total reward below 0.15.
    """
    # NOTE(review): cell-body indentation inside this embedded notebook source
    # was reconstructed (4 spaces) -- confirm against the original file.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    return fig, ax
@app.cell
def __():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    return fig, ax
"""
    reward, comp = compute_generate_reward(
        code=code,
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=False,
        accumulated_context=["linear regression least squares"],
        static_check_passed=False,
        error_codes=["MB002"],
    )
    assert 0.0 < comp["validity"] < 1.0
    assert reward < 0.15
def test_generate_reward_wrong_format():
    """The same code must earn more as ``marimo`` than as ``manim`` on MARIMO_TASK."""
    code = "import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n return\n"
    # Score the identical snippet under each declared format, marimo first.
    reward_by_format = {
        fmt: compute_generate_reward(code, fmt, "", MARIMO_TASK, False, [])[0]
        for fmt in ("marimo", "manim")
    }
    assert reward_by_format["marimo"] > reward_by_format["manim"]
def test_reward_spread():
    """Three snippets over the first five tasks must give at least 3 distinct rewards.

    Rewards are rounded to 3 decimals before counting distinct values.
    """
    snippets = ["bad!!!", "x = 1", "import marimo as mo\napp = mo.App()"]
    rewards = [
        compute_generate_reward(snippet, "marimo", "", task, False, [])[0]
        for task in ALL_TASKS[:5]
        for snippet in snippets
    ]
    distinct = {round(value, 3) for value in rewards}
    assert len(distinct) >= 3
def test_repair_reward_success_is_discounted_and_changed():
    """A successful repair of a base reward of 1.0 must come out as 0.72.

    Also checks that the reward stays within ``[0, MAX_REPAIR_REWARD]`` and
    that the ``repair_success``, ``fixed_prior_errors`` and ``changed_code``
    components are all 1.0.
    """
    import math  # function-scope import keeps the module's import block untouched

    reward, comp = adjust_repair_reward(
        1.0,
        repair_success=True,
        previous_error_codes=["PY_SYNTAX"],
        new_error_codes=[],
        previous_code="x =",
        repaired_code="x = 1",
    )
    # Compare with a tolerance: an exact ``==`` on a float product breaks on
    # harmless rounding (for example, 0.8 * 0.9 != 0.72 in binary floats).
    assert math.isclose(reward, 0.72, abs_tol=1e-9)
    assert 0.0 <= reward <= MAX_REPAIR_REWARD
    assert comp["repair_success"] == 1.0
    assert comp["fixed_prior_errors"] == 1.0
    assert comp["changed_code"] == 1.0
def test_repair_reward_penalizes_repeated_code():
    """Resubmitting unchanged code must earn less than submitting changed code.

    The unchanged submission must also report ``changed_code`` as 0.0.
    """

    def score_repair(candidate):
        # Same base reward and error history every time; only the code varies.
        return adjust_repair_reward(
            1.0,
            repair_success=True,
            previous_error_codes=["PY_SYNTAX"],
            new_error_codes=[],
            previous_code="x =",
            repaired_code=candidate,
        )

    changed_reward, _ = score_repair("x = 1")
    repeated_reward, components = score_repair("x =")

    assert repeated_reward < changed_reward
    assert components["changed_code"] == 0.0
def test_repair_reward_failed_fix_stays_discounted():
    """A failed repair that leaves the same error code must still be discounted.

    The reward must lie strictly between 0 and MAX_REPAIR_REWARD, and the
    ``repair_success`` and ``fixed_prior_errors`` components must both be 0.0.
    """
    reward, components = adjust_repair_reward(
        0.8,
        repair_success=False,
        previous_error_codes=["MB002"],
        new_error_codes=["MB002"],
        previous_code="x = 1",
        repaired_code="x = 2",
    )
    assert 0.0 < reward < MAX_REPAIR_REWARD
    for zeroed_component in ("repair_success", "fixed_prior_errors"):
        assert components[zeroed_component] == 0.0
def test_normalized_episode_score_bounds():
    """normalized_episode_score maps -1.0 to 0.0 and 999.0 to 1.0."""
    for raw_score, expected in ((-1.0, 0.0), (999.0, 1.0)):
        assert normalized_episode_score(raw_score) == expected
if __name__ == "__main__":
    # Minimal runner so the module can be executed directly without pytest.
    # Every test runs even if an earlier one fails. The summary line and the
    # process exit status both reflect whether any test failed.
    tests = [
        test_ast_parses,
        test_syntax_errors_are_verbose,
        test_marimo_duplicate_definitions_fail_static_check,
        test_marimo_runtime_rejects_numpy_math_namespace,
        test_query_relevance,
        test_result_novelty,
        test_action_novelty_penalizes_repeated_intent,
        test_research_breadth,
        test_tool_choice_score,
        test_source_quality,
        test_coverage_delta,
        test_explore_reward_integration,
        test_explore_reward_empty_result_is_gated,
        test_keyword_coverage,
        test_format_match,
        test_narration_marimo,
        test_narration_manim,
        test_structure_marimo,
        test_marimo_structure_prefers_reactive_plot_wrappers,
        test_context_usage,
        test_generate_reward_garbage,
        test_generate_reward_good,
        test_marimo_static_failure_is_not_code_valid,
        test_generate_reward_wrong_format,
        test_reward_spread,
        test_repair_reward_success_is_discounted_and_changed,
        test_repair_reward_penalizes_repeated_code,
        test_repair_reward_failed_fix_stays_discounted,
        test_normalized_episode_score_bounds,
    ]
    failed = []
    for t in tests:
        try:
            t()
        except Exception as e:  # top-level boundary: report and keep going
            failed.append(t.__name__)
            # Bare ``assert`` failures carry no message, so include the type.
            print(f"FAIL: {t.__name__}: {type(e).__name__}: {e}")
    passed = len(tests) - len(failed)
    # Only claim PASS when every test passed.
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_rewards ({passed}/{len(tests)})")
    if failed:
        # Non-zero exit status lets shells and CI detect the failure.
        sys.exit(1)