# explainer-env / tests / test_rewards.py
# kgdrathan's picture
# Upload folder using huggingface_hub
# 5869d56 verified
"""Tests for reward components — exploration and generation."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from constants import MAX_EXPLORE_REWARD, MAX_REPAIR_REWARD, normalized_episode_score
from rewards.exploration import (
action_novelty,
coverage_delta,
compute_explore_reward,
query_relevance,
research_breadth,
result_novelty,
source_quality,
tool_choice_score,
)
from rewards.generation import (
adjust_repair_reward,
compute_generate_reward,
context_usage,
format_match,
keyword_coverage,
marimo_structure,
narration_score,
)
from rewards.sandbox import ast_parses, validate_code
from research.types import ResearchChunk, ResearchResult
from task_bank import ALL_TASKS
# Fixture tasks shared by the tests below, selected from the task bank by topic.
# next() without a default raises StopIteration at import time if a topic is missing.
MARIMO_TASK = next(
    task for task in ALL_TASKS if task.topic == "Linear Regression"
)
MANIM_TASK = next(
    task for task in ALL_TASKS if task.topic == "Fourier Transform"
)
# --- Sandbox ---
def test_ast_parses():
    """ast_parses returns True for a valid statement and False for non-Python text."""
    valid_outcome = ast_parses("x = 1")
    assert valid_outcome is True

    invalid_outcome = ast_parses("not python!!!")
    assert invalid_outcome is False
def test_syntax_errors_are_verbose():
    """Rendered errors for an unclosed parenthesis carry a code, a line, and a caret."""
    report = validate_code("marimo", "x = (1 +\n").render_errors()
    # Each fragment must show up in the rendered text: the error code, a line
    # reference, and the caret marker.
    for fragment in ("PY_SYNTAX", "line", "^"):
        assert fragment in report
def test_marimo_duplicate_definitions_fail_static_check():
    """Two cells defining the same name parse but fail the static check.

    The notebook below assigns ``x`` in two separate cells. The validation
    result is expected to parse, fail the static check, and report ``MB002``.
    """
    # The embedded notebook must keep its own indentation: without it the cell
    # bodies are an IndentationError and ``result.parses`` could never be True.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    x = 1
    return x,
@app.cell
def __():
    x = 2
    return x,
"""
    result = validate_code("marimo", code)
    assert result.parses is True
    assert result.check_passed is False
    assert "MB002" in result.error_codes
def test_marimo_runtime_rejects_numpy_math_namespace():
    """A notebook that calls ``np.math`` passes static checks but fails at runtime.

    The expected outcome is: parses, static check passes, execution fails with
    the ``MARIMO_EXPORT`` error code, and the message mentions the offending
    attribute access.
    """
    # The embedded notebook must keep its own indentation so that it is valid
    # Python; otherwise the failure would be a syntax error, not a runtime one.
    # NOTE(review): presumably relies on the installed NumPy not exposing
    # ``np.math`` -- confirm against the pinned NumPy version.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    import numpy as np
    value = np.math.factorial(3)
    return value,
"""
    result = validate_code("marimo", code)
    assert result.parses is True
    assert result.check_passed is True
    assert result.exec_success is False
    assert "MARIMO_EXPORT" in result.error_codes
    assert "np.math" in result.message or "module 'numpy'" in result.message
# --- Exploration rewards ---
def test_query_relevance():
    """On-topic queries score high; empty queries score zero; off-topic ones score low."""
    topic = "Linear Regression"

    on_topic = query_relevance("linear regression MSE", topic, "linear regression,MSE")
    assert on_topic > 0.5

    empty_query = query_relevance("", topic, "x")
    assert empty_query == 0.0

    off_topic = query_relevance("cats", topic, "linear regression")
    assert off_topic < 0.3
def test_result_novelty():
    """Novelty is 1.0 with no history, low for a repeat, and 0.0 for empty text."""
    unseen = result_novelty("new information here", [])
    assert unseen == 1.0

    repeated = result_novelty("same words again", ["same words again"])
    assert repeated < 0.5

    blank = result_novelty("", [])
    assert blank == 0.0
def test_action_novelty_penalizes_repeated_intent():
    """Repeating a prior tool/query/intent scores low; a different action scores high."""
    history = ["search_wikipedia backpropagation algorithm neural network fundamentals"]

    # Same tool, query, and intent as the single history entry.
    repeat_score = action_novelty(
        "search_wikipedia",
        "backpropagation algorithm neural network",
        "fundamentals",
        history,
    )
    assert repeat_score < 0.3

    # Different tool, query, and intent.
    fresh_score = action_novelty(
        "fetch_docs",
        "marimo slider plotting examples",
        "interactive code patterns",
        history,
    )
    assert fresh_score > 0.7
def test_research_breadth():
    """Breadth is 0.0, 0.5, and 1.0 for zero, one, and two sources with min_sources=2."""
    expectations = (
        ([], 0.0),
        (["a"], 0.5),
        (["a", "b"], 1.0),
    )
    for sources, expected in expectations:
        assert research_breadth(sources, min_sources=2) == expected
def test_tool_choice_score():
    """Both tool/difficulty/intent combinations below receive a full score of 1.0."""
    full_score_cases = (
        ("search_arxiv", "hard", "recent research paper"),
        ("fetch_docs", "easy", "marimo plotting api"),
    )
    for tool, difficulty, intent in full_score_cases:
        assert tool_choice_score(tool, difficulty, intent) == 1.0
def test_source_quality():
    """A single arXiv chunk with repeated on-topic text scores above 0.7."""
    arxiv_chunk = ResearchChunk(
        source="arxiv",
        tool="search_arxiv",
        title="A paper",
        url="https://arxiv.org/abs/1",
        text="linear regression least squares optimization " * 10,
        score=1.0,
        metadata={"year": 2024},
    )
    research = ResearchResult(
        tool="search_arxiv",
        query="linear regression",
        chunks=[arxiv_chunk],
    )
    quality = source_quality(research)
    assert quality > 0.7
def test_coverage_delta():
    """Text that mentions MSE yields a positive coverage delta over an empty history."""
    delta = coverage_delta(
        "linear regression,MSE",
        "linear regression",
        [],
        "mean squared error MSE",
    )
    assert delta > 0.0
def test_explore_reward_integration():
    """End-to-end explore reward for one relevant Wikipedia chunk.

    The reward must exceed 0.1 without exceeding MAX_EXPLORE_REWARD, and the
    component breakdown must expose exactly the five expected keys.
    """
    wiki_chunk = ResearchChunk(
        source="wikipedia",
        tool="search_wikipedia",
        title="Linear regression",
        url="https://example.test",
        text="Linear regression minimizes squared error with least squares.",
        score=1.0,
        metadata={"page": "Linear regression"},
    )
    research = ResearchResult(
        tool="search_wikipedia",
        query="linear regression least squares",
        chunks=[wiki_chunk],
    )
    call_kwargs = {
        "query": "linear regression least squares",
        "tool": "search_wikipedia",
        "intent": "beginner explanation",
        "result": research,
        "topic": "Linear Regression",
        "keywords_csv": "linear regression,least squares,MSE",
        "task_content": "Linear regression is a method for modeling the relationship between variables.",
        "difficulty": "easy",
        "previous_context": [],
        "accumulated_context": ["first search result"],
        "used_tools": set(),
    }
    reward, comp = compute_explore_reward(**call_kwargs)

    assert reward > 0.1
    assert reward <= MAX_EXPLORE_REWARD

    expected_components = {
        "query_quality",
        "evidence_quality",
        "information_gain",
        "efficiency",
        "explore_total",
    }
    assert set(comp) == expected_components
def test_explore_reward_empty_result_is_gated():
    """A result with no chunks yields a near-zero reward and zero evidence/gain."""
    search_query = "linear regression least squares MSE"
    empty_research = ResearchResult(
        tool="search_wikipedia",
        query=search_query,
        chunks=[],
    )
    reward, comp = compute_explore_reward(
        query=search_query,
        tool="search_wikipedia",
        intent="beginner explanation",
        result=empty_research,
        topic="Linear Regression",
        keywords_csv="linear regression,least squares,MSE",
        task_content="",
        difficulty="easy",
        previous_context=[],
        accumulated_context=[],
        used_tools=set(),
    )
    assert reward < 0.05
    # With no chunks, both evidence-derived components must be exactly zero.
    for component in ("evidence_quality", "information_gain"):
        assert comp[component] == 0.0
# --- Generation rewards ---
def test_keyword_coverage():
    """Coverage is above 0.5 when two of three keywords appear and 0.0 when none do."""
    partial = keyword_coverage(
        "linear regression MSE",
        "linear regression,MSE,gradient descent",
    )
    assert partial > 0.5

    none_matched = keyword_coverage("nothing", "linear regression,MSE")
    assert none_matched == 0.0
def test_format_match():
    """Format match is 1.0 for the matching format and 0.3 for the other one.

    A task whose ``preferred_format`` is None scores 1.0 for any format.
    """
    assert format_match("marimo", MARIMO_TASK) == 1.0
    assert format_match("manim", MARIMO_TASK) == 0.3

    # Pick the first task in the bank that states no format preference.
    no_pref_task = next(
        task for task in ALL_TASKS if task.preferred_format is None
    )
    for fmt in ("marimo", "manim"):
        assert format_match(fmt, no_pref_task) == 1.0
def test_narration_marimo():
    """An empty narration still receives the full score for the marimo format."""
    score = narration_score("", "marimo")
    assert score == 1.0
def test_narration_manim():
    """For manim, an empty narration scores 0.0 and a multi-sentence one scores > 0.5."""
    assert narration_score("", "manim") == 0.0

    sentences = [
        "First we introduce the concept.",
        "Next we show the graph.",
        "Then we animate the transformation step by step.",
        "Finally we summarize the key takeaways from this scene.",
    ]
    long_narration = " ".join(sentences)
    assert narration_score(long_narration, "manim") > 0.5
def test_structure_marimo():
    """A notebook with markdown, a plotting import, and a slider scores above 0.5."""
    # The embedded notebook must keep its own indentation so that it is valid
    # Python source rather than a run of unindented function bodies.
    good = """import marimo
app = marimo.App()
@app.cell
def __():
    import marimo as mo
    return mo,
@app.cell
def __(mo):
    mo.md("# Regression")
    return ()
@app.cell
def __():
    import matplotlib.pyplot as plt
    return plt,
@app.cell
def __(mo):
    slider = mo.ui.slider(0, 5)
    return slider,
"""
    assert marimo_structure(good, MARIMO_TASK) > 0.5
def test_marimo_structure_prefers_reactive_plot_wrappers():
    """A plot wrapped in ``mo.ui.matplotlib`` outscores a bare figure expression.

    Both notebooks draw the same line plot; only the way the figure is
    surfaced differs. Both are scored with ``static_check_passed=True``.
    """
    # Both embedded notebooks must keep their own indentation so that they are
    # valid Python and the comparison is about structure, not syntax.
    raw = """import marimo
app = marimo.App()
@app.cell
def __():
    import numpy as np
    import matplotlib.pyplot as plt
    return np, plt
@app.cell
def __(np, plt):
    _x = np.linspace(0, 1, 10)
    _fig, _ax = plt.subplots()
    _ax.plot(_x, _x)
    _fig
    return ()
"""
    reactive = """import marimo
app = marimo.App()
@app.cell
def __():
    import marimo as mo
    import numpy as np
    import matplotlib.pyplot as plt
    return mo, np, plt
@app.cell
def __(mo, np, plt):
    _x = np.linspace(0, 1, 10)
    _fig, _ax = plt.subplots()
    _ax.plot(_x, _x)
    mo.ui.matplotlib(plt.gca())
    return ()
"""
    raw_score = marimo_structure(raw, MARIMO_TASK, static_check_passed=True)
    reactive_score = marimo_structure(reactive, MARIMO_TASK, static_check_passed=True)
    assert reactive_score > raw_score
def test_context_usage():
    """Usage is 0.0 with no context and above 0.3 when code overlaps the context."""
    without_context = context_usage("x = 1", [])
    assert without_context == 0.0

    overlapping = context_usage(
        "linear regression least squares gradient descent optimization",
        ["linear regression least squares optimization methods"],
    )
    assert overlapping > 0.3
def test_generate_reward_garbage():
    """Non-Python input gets a reward below 0.4 and a validity component of 0.0."""
    call_kwargs = {
        "code": "not python!!!",
        "fmt": "marimo",
        "narration": "",
        "task": MARIMO_TASK,
        "exec_success": False,
        "accumulated_context": [],
    }
    reward, comp = compute_generate_reward(**call_kwargs)
    assert reward < 0.4
    assert comp["validity"] == 0.0
def test_generate_reward_good():
    """A well-formed, on-topic marimo notebook earns a high generate reward.

    With execution and the static check both reported as successful, the
    reward must exceed 0.6, validity and task alignment must be 1.0, and the
    structure and research-usage components must clear their thresholds.
    """
    # The embedded notebook must keep its own indentation so that it is valid
    # Python source; the comment line inside the third cell carries the task
    # keywords that the scoring looks for.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    import marimo as mo
    import numpy as np
    import matplotlib.pyplot as plt
    return mo, np, plt
@app.cell
def __(mo):
    mo.md("# Linear Regression")
    return ()
@app.cell
def __(np):
    # linear regression least squares MSE gradient descent weights bias
    X = np.linspace(0, 10, 50)
    y = 2 * X + 1
    return X, y
@app.cell
def __(mo):
    slider = mo.ui.slider(0, 5, value=2, label="Slope")
    return slider,
"""
    reward, comp = compute_generate_reward(
        code=code,
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=True,
        accumulated_context=["linear regression least squares"],
        static_check_passed=True,
    )
    assert reward > 0.6
    assert comp["validity"] == 1.0
    assert comp["task_alignment"] == 1.0
    assert comp["structure"] > 0.8
    assert comp["research_usage"] > 0.5
def test_marimo_static_failure_is_not_code_valid():
    """A notebook that fails the static check gets only partial validity credit.

    Both cells define ``fig`` and ``ax``; the call reports the static check as
    failed with error code ``MB002``. Validity must be strictly between 0 and
    1 and the overall reward must stay below 0.15.
    """
    # The embedded notebook must keep its own indentation so that it is valid
    # Python source and the only problem is the duplicated definitions.
    code = """import marimo
app = marimo.App()
@app.cell
def __():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    return fig, ax
@app.cell
def __():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    return fig, ax
"""
    reward, comp = compute_generate_reward(
        code=code,
        fmt="marimo",
        narration="",
        task=MARIMO_TASK,
        exec_success=False,
        accumulated_context=["linear regression least squares"],
        static_check_passed=False,
        error_codes=["MB002"],
    )
    assert 0.0 < comp["validity"] < 1.0
    assert reward < 0.15
def test_generate_reward_wrong_format():
    """The same code earns more under the task's format than under the other one."""
    code = "import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n return\n"
    # Score the identical snippet once per format; only the format differs.
    reward_by_format = {
        fmt: compute_generate_reward(code, fmt, "", MARIMO_TASK, False, [])[0]
        for fmt in ("marimo", "manim")
    }
    assert reward_by_format["marimo"] > reward_by_format["manim"]
def test_reward_spread():
    """Three snippets across the first five tasks produce at least three distinct rewards."""
    snippets = ("bad!!!", "x = 1", "import marimo as mo\napp = mo.App()")
    rewards = [
        compute_generate_reward(snippet, "marimo", "", task, False, [])[0]
        for task in ALL_TASKS[:5]
        for snippet in snippets
    ]
    # Rewards are compared after rounding to three decimals.
    distinct = {round(value, 3) for value in rewards}
    assert len(distinct) >= 3
def test_repair_reward_success_is_discounted_and_changed():
    """A successful repair of a base reward of 1.0 is discounted to 0.72.

    The result must also stay within [0, MAX_REPAIR_REWARD], and the component
    breakdown must flag the repair as successful, the prior errors as fixed,
    and the code as changed.
    """
    reward, comp = adjust_repair_reward(
        1.0,
        repair_success=True,
        previous_error_codes=["PY_SYNTAX"],
        new_error_codes=[],
        previous_code="x =",
        repaired_code="x = 1",
    )
    # Compare with a tolerance: a product of discount factors such as
    # 0.9 * 0.8 is 0.7200000000000001 in binary floating point, so exact
    # equality would depend on how the implementation happens to compute it.
    assert abs(reward - 0.72) < 1e-9
    assert 0.0 <= reward <= MAX_REPAIR_REWARD
    assert comp["repair_success"] == 1.0
    assert comp["fixed_prior_errors"] == 1.0
    assert comp["changed_code"] == 1.0
def test_repair_reward_penalizes_repeated_code():
    """Resubmitting unchanged code earns less than submitting a changed repair."""

    def _repair(repaired_code):
        # Both calls are identical apart from the repaired code; fresh list
        # arguments are built for each call.
        return adjust_repair_reward(
            1.0,
            repair_success=True,
            previous_error_codes=["PY_SYNTAX"],
            new_error_codes=[],
            previous_code="x =",
            repaired_code=repaired_code,
        )

    changed_reward, _ = _repair("x = 1")
    repeated_reward, comp = _repair("x =")

    assert repeated_reward < changed_reward
    assert comp["changed_code"] == 0.0
def test_repair_reward_failed_fix_stays_discounted():
    """A failed repair with the same error code keeps a positive, sub-maximal reward."""
    unchanged_errors = "MB002"
    reward, comp = adjust_repair_reward(
        0.8,
        repair_success=False,
        previous_error_codes=[unchanged_errors],
        new_error_codes=[unchanged_errors],
        previous_code="x = 1",
        repaired_code="x = 2",
    )
    assert 0.0 < reward < MAX_REPAIR_REWARD
    # Neither success flag may be set when the same error persists.
    for flag in ("repair_success", "fixed_prior_errors"):
        assert comp[flag] == 0.0
def test_normalized_episode_score_bounds():
    """Out-of-range raw scores are clamped to 0.0 below and 1.0 above."""
    below_range = normalized_episode_score(-1.0)
    assert below_range == 0.0

    above_range = normalized_episode_score(999.0)
    assert above_range == 1.0
if __name__ == "__main__":
    # Minimal pytest-free runner: call every test in definition order, report
    # each failure, print a one-line summary, and exit non-zero if anything
    # failed so that shells and CI can detect a broken run.
    tests = [
        test_ast_parses,
        test_syntax_errors_are_verbose,
        test_marimo_duplicate_definitions_fail_static_check,
        test_marimo_runtime_rejects_numpy_math_namespace,
        test_query_relevance,
        test_result_novelty,
        test_action_novelty_penalizes_repeated_intent,
        test_research_breadth,
        test_tool_choice_score,
        test_source_quality,
        test_coverage_delta,
        test_explore_reward_integration,
        test_explore_reward_empty_result_is_gated,
        test_keyword_coverage,
        test_format_match,
        test_narration_marimo,
        test_narration_manim,
        test_structure_marimo,
        test_marimo_structure_prefers_reactive_plot_wrappers,
        test_context_usage,
        test_generate_reward_garbage,
        test_generate_reward_good,
        test_marimo_static_failure_is_not_code_valid,
        test_generate_reward_wrong_format,
        test_reward_spread,
        test_repair_reward_success_is_discounted_and_changed,
        test_repair_reward_penalizes_repeated_code,
        test_repair_reward_failed_fix_stays_discounted,
        test_normalized_episode_score_bounds,
    ]
    failed = []
    for test_fn in tests:
        try:
            test_fn()
        except Exception as exc:  # top-level boundary: report and keep going
            failed.append(test_fn.__name__)
            # Include the exception type: a bare ``assert`` raises an
            # AssertionError whose message is empty.
            print(f"FAIL: {test_fn.__name__}: {type(exc).__name__}: {exc}")
    passed = len(tests) - len(failed)
    # Only claim PASS when every test succeeded.
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_rewards ({passed}/{len(tests)})")
    if failed:
        sys.exit(1)