# explainer-env/tests/test_environment.py
# kgdrathan's picture
# Upload folder using huggingface_hub
# 8fa7af1 verified
"""Tests for ExplainerEnvironment — multi-step explore→generate lifecycle."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS
from models import ExplainerAction, ExplainerObservation
from server.explainer_env_environment import ExplainerEnvironment
def test_reset_returns_observation():
    """Check that reset() yields an explore-phase observation with a full budget."""
    initial = ExplainerEnvironment().reset(seed=1)

    assert isinstance(initial, ExplainerObservation)
    assert initial.topic != ""
    assert initial.phase == "explore"
    assert initial.explore_steps_left == MAX_EXPLORE_STEPS
    assert initial.done is False
def test_reset_deterministic_with_seed():
    """Check that two resets with the same seed report the same topic."""
    environment = ExplainerEnvironment()
    # Both resets run (in order) before the topics are compared.
    first, second = (environment.reset(seed=42) for _ in range(2))
    assert first.topic == second.topic
def test_explore_step():
    """Check one explore action: a step is consumed and a numeric reward returned."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    search_fields = {
        "action_type": "explore",
        "tool": "search_wikipedia",
        "query": "gradient descent optimization",
        "intent": "beginner explanation",
    }
    result = environment.step(ExplainerAction(**search_fields))

    assert result.done is False
    assert result.explore_steps_left == MAX_EXPLORE_STEPS - 1
    assert isinstance(result.reward, (int, float))
    assert result.reward >= 0.0
    assert isinstance(result.top_chunks, list)
def test_explore_empty_query():
    """Check that an empty explore query earns zero reward and an explanatory note."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    result = environment.step(
        ExplainerAction(action_type="explore", tool="search_wikipedia", query="")
    )

    assert result.reward == 0.0
    assert "Empty query" in result.feedback
def test_explore_max_steps():
    """Check that spending every explore step moves the episode to "generate"."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    for step_index in range(MAX_EXPLORE_STEPS):
        search = ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query=f"search {step_index}",
        )
        last = environment.step(search)

    # Only the observation from the final explore step is inspected.
    assert last.phase == "generate"
    assert last.explore_steps_left == 0
def test_explore_then_generate():
    """Check the explore -> generate path: results come back, then a scored phase."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    # Explore: the episode stays open and search results are populated.
    explored = environment.step(
        ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query="gradient descent",
        )
    )
    assert explored.done is False
    assert explored.search_results != ""

    # Generate: submit a minimal marimo notebook and inspect the outcome.
    notebook_source = "import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n return\n"
    generated = environment.step(
        ExplainerAction(
            action_type="generate",
            format="marimo",
            code=notebook_source,
        )
    )
    assert generated.phase in ("repair", "done")
    assert isinstance(generated.reward, (int, float))
def test_generate_without_explore_penalty():
    """Check that generating with no prior exploration is flagged in the feedback."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    result = environment.step(
        ExplainerAction(action_type="generate", format="marimo", code="x = 1")
    )

    assert result.done is False
    assert result.phase == "repair"
    # Either keyword is accepted as evidence that the shortcut was noticed.
    assert any(
        keyword in result.feedback.lower() for keyword in ("penalty", "without")
    )
def test_step_without_reset():
    """Check that stepping before reset() ends the episode with reward -1.0."""
    environment = ExplainerEnvironment()

    result = environment.step(
        ExplainerAction(action_type="explore", tool="search_wikipedia", query="test")
    )

    assert result.done is True
    assert result.reward == -1.0
def test_generate_reward_in_metadata():
    """Check that a generate step's metadata carries the expected reward keys."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    environment.step(
        ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query="gradient descent",
        )
    )
    result = environment.step(
        ExplainerAction(action_type="generate", format="marimo", code="x = 1")
    )

    expected_components = ("validity", "task_alignment", "structure", "research_usage")
    for component in expected_components:
        assert component in result.metadata, f"missing {component} in metadata"
    assert "explore_steps_used" in result.metadata
def test_state_episode_id_changes():
    """Check that consecutive resets produce different episode ids."""
    environment = ExplainerEnvironment()

    environment.reset()
    before = environment.state.episode_id

    environment.reset()
    after = environment.state.episode_id

    assert before != after
def test_step_increments_count():
    """Check that state.step_count starts at 0 and grows by one per step."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    assert environment.state.step_count == 0

    explore = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="test",
    )
    environment.step(explore)
    assert environment.state.step_count == 1

    generate = ExplainerAction(action_type="generate", format="marimo", code="x=1")
    environment.step(generate)
    assert environment.state.step_count == 2
def test_bad_code_does_not_crash():
    """Check that unparsable code leads to the repair phase with a syntax report."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    broken_source = ")))syntax error((("
    result = environment.step(
        ExplainerAction(action_type="generate", format="marimo", code=broken_source)
    )

    assert result.done is False
    assert result.phase == "repair"
    assert "SYNTAX ERROR" in result.feedback
def test_failed_repair_can_continue_until_limit():
    """Check that failed repairs keep the episode open until the budget is spent."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    environment.step(
        ExplainerAction(action_type="generate", format="marimo", code="x = 1")
    )

    # First repair attempt: still in repair, with one attempt consumed.
    result = environment.step(
        ExplainerAction(
            action_type="repair",
            format="marimo",
            code="x = 2",
            repair_notes="attempted fix",
        )
    )
    assert result.done is False
    assert result.phase == "repair"
    assert result.repair_attempts_left == MAX_REPAIR_STEPS - 1
    assert result.metadata["phase"] == "repair"

    # Use up the remaining attempts; submitted values continue from 3.
    for value in range(3, MAX_REPAIR_STEPS + 2):
        result = environment.step(
            ExplainerAction(
                action_type="repair",
                format="marimo",
                code=f"x = {value}",
                repair_notes="still invalid",
            )
        )

    assert result.done is True
    assert result.phase == "done"
    assert result.repair_attempts_left == 0
if __name__ == "__main__":
    # Minimal fallback runner for running this file directly (without pytest).
    # Every test is executed even if an earlier one fails; the summary line and
    # the process exit status both reflect whether any test failed.
    tests = [
        test_reset_returns_observation,
        test_reset_deterministic_with_seed,
        test_explore_step,
        test_explore_empty_query,
        test_explore_max_steps,
        test_explore_then_generate,
        test_generate_without_explore_penalty,
        test_step_without_reset,
        test_generate_reward_in_metadata,
        test_state_episode_id_changes,
        test_step_increments_count,
        test_bad_code_does_not_crash,
        test_failed_repair_can_continue_until_limit,
    ]
    failed = []
    for t in tests:
        try:
            t()
        except Exception as e:  # runner boundary: report and keep going
            failed.append(t.__name__)
            # Include the exception type: a bare ``assert`` has an empty message.
            print(f"FAIL: {t.__name__}: {type(e).__name__}: {e}")
    passed = len(tests) - len(failed)
    # Only claim PASS when every test actually passed.
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_environment ({passed}/{len(tests)})")
    if failed:
        # Non-zero exit so shells and CI can detect the failure.
        sys.exit(1)