"""Tests for ExplainerEnvironment — multi-step explore→generate lifecycle.

Runnable either under pytest (which collects the ``test_*`` functions) or
directly as a script via the ``__main__`` runner at the bottom.
"""

import sys
import traceback
from pathlib import Path

# Put the parent of this file's directory on sys.path so the project modules
# below can be imported when this file is run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS  # noqa: E402
from models import ExplainerAction, ExplainerObservation  # noqa: E402
from server.explainer_env_environment import ExplainerEnvironment  # noqa: E402


def test_reset_returns_observation():
    """reset() returns an explore-phase observation with the full explore budget."""
    env = ExplainerEnvironment()
    obs = env.reset(seed=1)
    assert isinstance(obs, ExplainerObservation)
    assert obs.topic != ""
    assert obs.phase == "explore"
    assert obs.explore_steps_left == MAX_EXPLORE_STEPS
    assert obs.done is False


def test_reset_deterministic_with_seed():
    """Resetting twice with the same seed selects the same topic."""
    env = ExplainerEnvironment()
    obs1 = env.reset(seed=42)
    obs2 = env.reset(seed=42)
    assert obs1.topic == obs2.topic


def test_explore_step():
    """A valid explore step uses one explore step and yields a non-negative reward."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    action = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent optimization",
        intent="beginner explanation",
    )
    obs = env.step(action)
    assert obs.done is False
    assert obs.explore_steps_left == MAX_EXPLORE_STEPS - 1
    assert isinstance(obs.reward, (int, float))
    assert obs.reward >= 0.0
    assert isinstance(obs.top_chunks, list)


def test_explore_empty_query():
    """An empty explore query earns zero reward and is reported in the feedback."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    action = ExplainerAction(action_type="explore", tool="search_wikipedia", query="")
    obs = env.step(action)
    assert obs.reward == 0.0
    assert "Empty query" in obs.feedback


def test_explore_max_steps():
    """Using the whole explore budget moves the episode into the generate phase."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    for i in range(MAX_EXPLORE_STEPS):
        obs = env.step(ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query=f"search {i}",
        ))
    assert obs.phase == "generate"
    assert obs.explore_steps_left == 0


def test_explore_then_generate():
    """One explore step followed by a generate step lands in repair or done."""
    env = ExplainerEnvironment()
    env.reset(seed=1)

    # Explore
    obs = env.step(ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent",
    ))
    assert obs.done is False
    assert obs.search_results != ""

    # Generate
    obs = env.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code="import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n return\n",
    ))
    assert obs.phase in ("repair", "done")
    assert isinstance(obs.reward, (int, float))


def test_generate_without_explore_penalty():
    """Generating with no prior explore step is flagged but the episode continues in repair."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    obs = env.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code="x = 1",
    ))
    assert obs.done is False
    assert obs.phase == "repair"
    feedback = obs.feedback.lower()
    assert "penalty" in feedback or "without" in feedback


def test_step_without_reset():
    """Stepping before reset() ends the episode immediately with a -1.0 reward."""
    env = ExplainerEnvironment()
    action = ExplainerAction(action_type="explore", tool="search_wikipedia", query="test")
    obs = env.step(action)
    assert obs.done is True
    assert obs.reward == -1.0


def test_generate_reward_in_metadata():
    """After generate, the reward components and explore usage appear in metadata."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    env.step(ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent",
    ))
    obs = env.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code="x = 1",
    ))
    for key in ("validity", "task_alignment", "structure", "research_usage"):
        assert key in obs.metadata, f"missing {key} in metadata"
    assert "explore_steps_used" in obs.metadata


def test_state_episode_id_changes():
    """Each reset() starts a new episode with a different episode_id."""
    env = ExplainerEnvironment()
    env.reset()
    eid1 = env.state.episode_id
    env.reset()
    eid2 = env.state.episode_id
    assert eid1 != eid2


def test_step_increments_count():
    """state.step_count starts at 0 after reset and increases by one per step."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    assert env.state.step_count == 0
    env.step(ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="test",
    ))
    assert env.state.step_count == 1
    env.step(ExplainerAction(action_type="generate", format="marimo", code="x=1"))
    assert env.state.step_count == 2


def test_bad_code_does_not_crash():
    """Syntactically invalid generated code is reported, not raised, and enters repair."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    obs = env.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code=")))syntax error(((",
    ))
    assert obs.done is False
    assert obs.phase == "repair"
    assert "SYNTAX ERROR" in obs.feedback


def test_failed_repair_can_continue_until_limit():
    """Failed repairs stay in the repair phase until MAX_REPAIR_STEPS are used, then end."""
    env = ExplainerEnvironment()
    env.reset(seed=1)
    env.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code="x = 1",
    ))
    obs = env.step(ExplainerAction(
        action_type="repair",
        format="marimo",
        code="x = 2",
        repair_notes="attempted fix",
    ))
    assert obs.done is False
    assert obs.phase == "repair"
    assert obs.repair_attempts_left == MAX_REPAIR_STEPS - 1
    assert obs.metadata["phase"] == "repair"

    # Spend the remaining repair budget with still-invalid code.
    for attempt in range(MAX_REPAIR_STEPS - 1):
        obs = env.step(ExplainerAction(
            action_type="repair",
            format="marimo",
            code=f"x = {attempt + 3}",
            repair_notes="still invalid",
        ))
    assert obs.done is True
    assert obs.phase == "done"
    assert obs.repair_attempts_left == 0


if __name__ == "__main__":
    # Discover tests automatically so new test_* functions cannot be forgotten.
    # Snapshot globals() first; dict insertion order keeps definition order.
    tests = [
        obj
        for name, obj in list(globals().items())
        if name.startswith("test_") and callable(obj)
    ]
    failures = 0
    for t in tests:
        try:
            t()
        except Exception as e:
            failures += 1
            # A bare `assert` raises AssertionError with an empty message, so
            # report the exception type and the traceback to make it actionable.
            print(f"FAIL: {t.__name__}: {type(e).__name__}: {e}")
            traceback.print_exc()
    passed = len(tests) - failures
    status = "PASS" if failures == 0 else "FAIL"
    print(f"{status}: test_environment ({passed}/{len(tests)})")
    # Exit non-zero on any failure so shells and CI detect it.
    sys.exit(1 if failures else 0)