Spaces:
Sleeping
Sleeping
| """Tests for ExplainerEnvironment — multi-step explore→generate lifecycle.""" | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS | |
| from models import ExplainerAction, ExplainerObservation | |
| from server.explainer_env_environment import ExplainerEnvironment | |
def test_reset_returns_observation():
    """reset() hands back an ExplainerObservation positioned at the start of explore."""
    environment = ExplainerEnvironment()
    initial = environment.reset(seed=1)

    # Type first, then the individual fields of the initial observation.
    assert isinstance(initial, ExplainerObservation)
    assert initial.topic != ""
    assert initial.phase == "explore"
    assert initial.explore_steps_left == MAX_EXPLORE_STEPS
    assert initial.done is False
def test_reset_deterministic_with_seed():
    """Two resets of the same environment with seed=42 report the same topic."""
    environment = ExplainerEnvironment()
    # Both resets happen before either topic is read.
    first, second = [environment.reset(seed=42) for _ in range(2)]
    assert first.topic == second.topic
def test_explore_step():
    """A single explore action consumes one explore step and returns a scored result."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    search = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent optimization",
        intent="beginner explanation",
    )
    result = environment.step(search)

    # The episode continues and exactly one explore step has been spent.
    assert result.done is False
    assert result.explore_steps_left == MAX_EXPLORE_STEPS - 1

    # Reward is numeric and non-negative; retrieved chunks come back as a list.
    assert isinstance(result.reward, (int, float))
    assert result.reward >= 0.0
    assert isinstance(result.top_chunks, list)
def test_explore_empty_query():
    """An explore action with an empty query earns zero reward and an 'Empty query' note."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    result = environment.step(
        ExplainerAction(action_type="explore", tool="search_wikipedia", query="")
    )

    assert result.reward == 0.0
    assert "Empty query" in result.feedback
def test_explore_max_steps():
    """Spending every explore step moves the observation into the generate phase."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    for i in range(MAX_EXPLORE_STEPS):
        search = ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query=f"search {i}",
        )
        last = environment.step(search)

    # Only the observation from the final explore step is inspected.
    assert last.phase == "generate"
    assert last.explore_steps_left == 0
def test_explore_then_generate():
    """One explore step followed by a generate step lands in 'repair' or 'done'."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    # Phase 1: a single search that must return some results.
    explored = environment.step(
        ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query="gradient descent",
        )
    )
    assert explored.done is False
    assert explored.search_results != ""

    # Phase 2: submit a small marimo app as the generated artifact.
    generated = environment.step(
        ExplainerAction(
            action_type="generate",
            format="marimo",
            code="import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n return\n",
        )
    )
    assert generated.phase in ("repair", "done")
    assert isinstance(generated.reward, (int, float))
def test_generate_without_explore_penalty():
    """Generating before any explore step goes to repair and the feedback mentions it."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    generate = ExplainerAction(
        action_type="generate",
        format="marimo",
        code="x = 1",
    )
    result = environment.step(generate)

    assert result.done is False
    assert result.phase == "repair"
    # Either keyword in the (case-insensitive) feedback is accepted.
    assert "penalty" in result.feedback.lower() or "without" in result.feedback.lower()
def test_step_without_reset():
    """Stepping an environment that was never reset ends the episode with reward -1.0."""
    environment = ExplainerEnvironment()
    search = ExplainerAction(action_type="explore", tool="search_wikipedia", query="test")

    # No reset() call on purpose.
    result = environment.step(search)

    assert result.done is True
    assert result.reward == -1.0
def test_generate_reward_in_metadata():
    """After explore + generate, the observation metadata carries the expected keys."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    search = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent",
    )
    environment.step(search)

    generate = ExplainerAction(
        action_type="generate",
        format="marimo",
        code="x = 1",
    )
    result = environment.step(generate)

    # Each of these keys must be present in the metadata mapping.
    expected_keys = ("validity", "task_alignment", "structure", "research_usage")
    for key in expected_keys:
        assert key in result.metadata, f"missing {key} in metadata"
    assert "explore_steps_used" in result.metadata
def test_state_episode_id_changes():
    """Each reset() gives the environment state a different episode_id."""
    environment = ExplainerEnvironment()

    environment.reset()
    before = environment.state.episode_id

    environment.reset()
    after = environment.state.episode_id

    assert before != after
def test_step_increments_count():
    """state.step_count starts at 0 after reset and grows by one per step() call."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    assert environment.state.step_count == 0

    search = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="test",
    )
    environment.step(search)
    assert environment.state.step_count == 1

    generate = ExplainerAction(action_type="generate", format="marimo", code="x=1")
    environment.step(generate)
    assert environment.state.step_count == 2
def test_bad_code_does_not_crash():
    """Unparseable code yields a repair-phase observation flagged with 'SYNTAX ERROR'."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    broken = ExplainerAction(
        action_type="generate",
        format="marimo",
        code=")))syntax error(((",
    )
    result = environment.step(broken)

    assert result.done is False
    assert result.phase == "repair"
    assert "SYNTAX ERROR" in result.feedback
def test_failed_repair_can_continue_until_limit():
    """Repairs keep the episode open until MAX_REPAIR_STEPS attempts have been made."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    # Initial generate step; its observation is not inspected here.
    environment.step(
        ExplainerAction(
            action_type="generate",
            format="marimo",
            code="x = 1",
        )
    )

    # First repair attempt: the episode stays in repair with one attempt used.
    first_repair = ExplainerAction(
        action_type="repair",
        format="marimo",
        code="x = 2",
        repair_notes="attempted fix",
    )
    result = environment.step(first_repair)
    assert result.done is False
    assert result.phase == "repair"
    assert result.repair_attempts_left == MAX_REPAIR_STEPS - 1
    assert result.metadata["phase"] == "repair"

    # Use up the remaining attempts; only the last observation is checked.
    for attempt in range(MAX_REPAIR_STEPS - 1):
        retry = ExplainerAction(
            action_type="repair",
            format="marimo",
            code=f"x = {attempt + 3}",
            repair_notes="still invalid",
        )
        result = environment.step(retry)

    assert result.done is True
    assert result.phase == "done"
    assert result.repair_attempts_left == 0
if __name__ == "__main__":
    # Minimal standalone runner so the suite can be executed without pytest.
    # Every test is run even if an earlier one fails; the summary line and the
    # process exit status both reflect whether any test failed.
    tests = [
        test_reset_returns_observation,
        test_reset_deterministic_with_seed,
        test_explore_step,
        test_explore_empty_query,
        test_explore_max_steps,
        test_explore_then_generate,
        test_generate_without_explore_penalty,
        test_step_without_reset,
        test_generate_reward_in_metadata,
        test_state_episode_id_changes,
        test_step_increments_count,
        test_bad_code_does_not_crash,
        test_failed_repair_can_continue_until_limit,
    ]
    passed = 0
    for t in tests:
        try:
            t()
        except Exception as e:
            # Top-level boundary: report the failure and keep going.
            # The exception type is included because a bare `assert` raises
            # AssertionError with an empty message.
            print(f"FAIL: {t.__name__}: {type(e).__name__}: {e}")
        else:
            passed += 1

    # Report PASS only when every test succeeded.
    status = "PASS" if passed == len(tests) else "FAIL"
    print(f"{status}: test_environment ({passed}/{len(tests)})")
    if passed != len(tests):
        # Non-zero exit status so shells and CI see the failure.
        raise SystemExit(1)