sql-db-engineer-agent / tests /test_environment.py
junaid0600's picture
Update tests/test_environment.py
eca453d verified
Raw
History Blame Contribute Delete
4.55 kB
import pytest
from env.environment import SQLDebuggerEnvironment
from env.models import Action, ActionType, DifficultyLevel
@pytest.fixture
def env():
e = SQLDebuggerEnvironment()
return e
def test_state_before_reset(env):
"""state() before reset must not crash — returns default state."""
s = env.state()
assert s.initialized == False
assert s.step_count == 0
def test_reset_easy(env):
obs = env.reset(difficulty="easy")
assert obs.task_id.startswith("easy")
assert obs.step_count == 0
assert obs.difficulty == DifficultyLevel.EASY
assert "fixed_query" not in obs.current_context
assert "buggy_query" in obs.current_context or "slow_queries" in obs.current_context
def test_reset_medium(env):
obs = env.reset(difficulty="medium")
assert obs.task_id.startswith("medium")
def test_reset_hard(env):
obs = env.reset(difficulty="hard")
assert obs.task_id.startswith("hard")
def test_reset_clears_state(env):
"""Reset mid-episode must clear all state — no leakage."""
env.reset(difficulty="easy")
action = Action(action_type=ActionType.IDENTIFY_ERROR,
payload={"error_location": "SELECT", "error_type": "syntax"})
env.step(action)
assert env.state().step_count == 1
# Reset mid-episode
env.reset(difficulty="medium")
assert env.state().step_count == 0
assert env.state().total_reward == 0.0
assert env.state().previous_actions == []
def test_step_identify_error(env):
# Use Round 1 task which has no DB simulator target to hit
env.reset(difficulty="easy", task_id="easy_001")
action = Action(action_type=ActionType.IDENTIFY_ERROR,
payload={"error_location": "SELECT clause", "error_type": "syntax",
"explanation": "Missing commas"})
resp = env.step(action)
assert resp.reward.score > 0
assert resp.done == False # identify_error is not terminal
def test_step_null_action(env):
"""Null action must return 0.0 or greater, never crash."""
env.reset(difficulty="easy")
resp = env.step(None)
assert resp.reward.score >= 0.0
assert resp.done == False
def test_step_after_done(env):
"""Step after done must not crash."""
env.reset(difficulty="easy", task_id="easy_001")
action = Action(action_type=ActionType.SUBMIT_ANSWER,
payload={"fixed_query": "SELECT id, name, email FROM users WHERE active = 1",
"explanation": "Fixed", "confidence": 0.9})
env.step(action)
assert env.state().done == True
# Step again after done
resp = env.step(action)
assert resp.done == True
assert "Call reset()" in resp.reward.feedback
def test_dense_reward(env):
"""Reward must vary at each step — not only at end."""
env.reset(difficulty="easy")
rewards = []
actions = [
Action(action_type=ActionType.IDENTIFY_ERROR,
payload={"error_location": "SELECT", "error_type": "syntax"}),
Action(action_type=ActionType.EXPLAIN_ISSUE,
payload={"explanation": "Missing commas between column names in SELECT"}),
]
for a in actions:
r = env.step(a)
rewards.append(r.reward.score)
if r.done:
break
# Rewards must not all be zero
assert any(r != 0.0 for r in rewards)
def test_max_steps(env):
"""Episode must terminate at max_steps."""
env.reset(difficulty="easy")
action = Action(action_type=ActionType.IDENTIFY_ERROR,
payload={"error_location": "x", "error_type": "syntax"})
done = False
for _ in range(55):
resp = env.step(action)
if resp.done:
done = True
break
assert done == True
def test_hint_injected_in_context(env):
"""Hint must appear in next observation after request_hint."""
env.reset(difficulty="easy")
action = Action(action_type=ActionType.REQUEST_HINT,
payload={"hint_type": "location"})
resp = env.step(action)
assert "last_hint" in resp.observation.current_context
def test_state_reflects_latest_step(env):
"""state() must always reflect the latest step accurately."""
env.reset(difficulty="easy")
action = Action(action_type=ActionType.IDENTIFY_ERROR,
payload={"error_location": "SELECT", "error_type": "syntax"})
env.step(action)
s = env.state()
assert s.step_count == 1
assert s.initialized == True
assert "identify_error" in s.previous_actions