"""Tests for ExplainerEnvironment — multi-step explore→generate lifecycle."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS
from models import ExplainerAction, ExplainerObservation
from server.explainer_env_environment import ExplainerEnvironment
def test_reset_returns_observation():
    """reset() returns a non-terminal explore-phase observation with a full explore budget."""
    environment = ExplainerEnvironment()
    initial = environment.reset(seed=1)

    # Type and basic content of the first observation.
    assert isinstance(initial, ExplainerObservation)
    assert initial.topic != ""

    # Episode bookkeeping right after reset.
    assert initial.phase == "explore"
    assert initial.explore_steps_left == MAX_EXPLORE_STEPS
    assert initial.done is False
def test_reset_deterministic_with_seed():
    """Two resets of the same environment with the same seed yield the same topic."""
    environment = ExplainerEnvironment()
    first = environment.reset(seed=42)
    second = environment.reset(seed=42)

    assert first.topic == second.topic
def test_explore_step():
    """One explore action keeps the episode open and spends exactly one explore step.

    Also checks that the reward is a non-negative number and that
    ``top_chunks`` is a list.
    """
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    result = environment.step(
        ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query="gradient descent optimization",
            intent="beginner explanation",
        )
    )

    assert result.done is False
    assert result.explore_steps_left == MAX_EXPLORE_STEPS - 1
    assert isinstance(result.reward, (int, float))
    assert result.reward >= 0.0
    assert isinstance(result.top_chunks, list)
def test_explore_empty_query():
    """An explore action with an empty query earns zero reward and an "Empty query" note."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    result = environment.step(
        ExplainerAction(action_type="explore", tool="search_wikipedia", query="")
    )

    assert result.reward == 0.0
    assert "Empty query" in result.feedback
def test_explore_max_steps():
    """Using every explore step leaves the episode in the generate phase with no budget."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    # Spend the whole explore budget; only the final observation is inspected.
    for step_index in range(MAX_EXPLORE_STEPS):
        search = ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query=f"search {step_index}",
        )
        last = environment.step(search)

    assert last.phase == "generate"
    assert last.explore_steps_left == 0
def test_explore_then_generate():
    """An explore step followed by a generate step ends in the repair or done phase."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    # Step 1: explore — must return search results without ending the episode.
    explored = environment.step(
        ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query="gradient descent",
        )
    )
    assert explored.done is False
    assert explored.search_results != ""

    # Step 2: generate — a minimal marimo app is submitted as the artifact.
    generated = environment.step(
        ExplainerAction(
            action_type="generate",
            format="marimo",
            code="import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n return\n",
        )
    )
    assert generated.phase in ("repair", "done")
    assert isinstance(generated.reward, (int, float))
def test_generate_without_explore_penalty():
    """Generating before any exploration lands in repair and the feedback says so."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    result = environment.step(
        ExplainerAction(action_type="generate", format="marimo", code="x = 1")
    )

    assert result.done is False
    assert result.phase == "repair"
    # The feedback must mention either the penalty or the missing exploration.
    assert any(
        marker in result.feedback.lower() for marker in ("penalty", "without")
    )
def test_step_without_reset():
    """Stepping an environment that was never reset ends the episode with reward -1.0."""
    environment = ExplainerEnvironment()
    premature = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="test",
    )

    result = environment.step(premature)

    assert result.done is True
    assert result.reward == -1.0
def test_generate_reward_in_metadata():
    """After explore + generate, the metadata exposes each reward component."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    environment.step(
        ExplainerAction(
            action_type="explore",
            tool="search_wikipedia",
            query="gradient descent",
        )
    )
    generated = environment.step(
        ExplainerAction(action_type="generate", format="marimo", code="x = 1")
    )

    metadata = generated.metadata
    expected_components = ("validity", "task_alignment", "structure", "research_usage")
    for component in expected_components:
        assert component in metadata, f"missing {component} in metadata"
    assert "explore_steps_used" in metadata
def test_state_episode_id_changes():
    """Each reset() assigns a different episode id to the environment state."""
    environment = ExplainerEnvironment()

    episode_ids = []
    for _ in range(2):
        environment.reset()
        episode_ids.append(environment.state.episode_id)

    first_id, second_id = episode_ids
    assert first_id != second_id
def test_step_increments_count():
    """state.step_count starts at 0 after reset and grows by one per step."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    assert environment.state.step_count == 0

    explore = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="test",
    )
    environment.step(explore)
    assert environment.state.step_count == 1

    generate = ExplainerAction(action_type="generate", format="marimo", code="x=1")
    environment.step(generate)
    assert environment.state.step_count == 2
def test_bad_code_does_not_crash():
    """Unparseable code is reported as a syntax error and moves the episode to repair."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    broken_submission = ExplainerAction(
        action_type="generate",
        format="marimo",
        code=")))syntax error(((",
    )
    result = environment.step(broken_submission)

    assert result.done is False
    assert result.phase == "repair"
    assert "SYNTAX ERROR" in result.feedback
def test_failed_repair_can_continue_until_limit():
    """Failed repairs keep the episode open until MAX_REPAIR_STEPS attempts are spent."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    # An invalid first submission puts the episode into the repair phase.
    environment.step(
        ExplainerAction(action_type="generate", format="marimo", code="x = 1")
    )

    # The first repair fails too: still in repair, with one attempt consumed.
    result = environment.step(
        ExplainerAction(
            action_type="repair",
            format="marimo",
            code="x = 2",
            repair_notes="attempted fix",
        )
    )
    assert result.done is False
    assert result.phase == "repair"
    assert result.repair_attempts_left == MAX_REPAIR_STEPS - 1
    assert result.metadata["phase"] == "repair"

    # Use up the remaining attempts; submissions continue from "x = 3" upward.
    for value in range(3, MAX_REPAIR_STEPS + 2):
        retry = ExplainerAction(
            action_type="repair",
            format="marimo",
            code=f"x = {value}",
            repair_notes="still invalid",
        )
        result = environment.step(retry)

    assert result.done is True
    assert result.phase == "done"
    assert result.repair_attempts_left == 0
if __name__ == "__main__":
    # Minimal stand-alone runner so the suite can be executed without pytest.
    # Every test runs even if an earlier one fails; results are summarised at
    # the end and reflected in the process exit code.
    tests = [
        test_reset_returns_observation,
        test_reset_deterministic_with_seed,
        test_explore_step,
        test_explore_empty_query,
        test_explore_max_steps,
        test_explore_then_generate,
        test_generate_without_explore_penalty,
        test_step_without_reset,
        test_generate_reward_in_metadata,
        test_state_episode_id_changes,
        test_step_increments_count,
        test_bad_code_does_not_crash,
        test_failed_repair_can_continue_until_limit,
    ]
    failed = []
    for test in tests:
        try:
            test()
        except Exception as exc:
            # A bare ``assert`` raises AssertionError with an empty message,
            # so include the exception type to keep the failure line useful.
            failed.append(test.__name__)
            print(f"FAIL: {test.__name__}: {type(exc).__name__}: {exc}")

    passed = len(tests) - len(failed)
    # Only claim PASS when every test succeeded.
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_environment ({passed}/{len(tests)})")
    if failed:
        # Non-zero exit status lets shells and CI detect the failure.
        sys.exit(1)
|