Spaces:
Sleeping
Sleeping
File size: 4,546 Bytes
3c1b0c7 8cb206e 3c1b0c7 8778707 3c1b0c7 8778707 3c1b0c7 eca453d 3c1b0c7 eca453d 3c1b0c7 8cb206e 3c1b0c7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | import pytest
from env.environment import SQLDebuggerEnvironment
from env.models import Action, ActionType, DifficultyLevel
@pytest.fixture
def env():
    """Provide a brand-new SQLDebuggerEnvironment to the requesting test."""
    return SQLDebuggerEnvironment()
def test_state_before_reset(env):
    """state() before reset must not crash — returns default state.

    A never-reset environment reports itself as uninitialized with zero steps.
    """
    s = env.state()
    assert not s.initialized
    assert s.step_count == 0
def test_reset_easy(env):
    """Resetting on the easy tier yields a fresh easy observation.

    The observation must not expose the corrected query, and must carry the
    material to work on (a buggy query or a set of slow queries).
    """
    observation = env.reset(difficulty="easy")

    assert observation.task_id.startswith("easy")
    assert observation.step_count == 0
    assert observation.difficulty == DifficultyLevel.EASY

    context = observation.current_context
    # The expected answer (the fixed query) must not leak to the agent.
    assert "fixed_query" not in context
    assert any(key in context for key in ("buggy_query", "slow_queries"))
def test_reset_medium(env):
    """Resetting on the medium tier starts a fresh medium episode."""
    obs = env.reset(difficulty="medium")
    assert obs.task_id.startswith("medium")
    # Mirrors test_reset_easy: a freshly reset episode has taken no steps.
    assert obs.step_count == 0
def test_reset_hard(env):
    """Resetting on the hard tier starts a fresh hard episode."""
    obs = env.reset(difficulty="hard")
    assert obs.task_id.startswith("hard")
    # Mirrors test_reset_easy: a freshly reset episode has taken no steps.
    assert obs.step_count == 0
def test_reset_clears_state(env):
    """Resetting mid-episode must wipe all per-episode state (no leakage)."""
    env.reset(difficulty="easy")
    probe = Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload={"error_location": "SELECT", "error_type": "syntax"},
    )
    env.step(probe)
    assert env.state().step_count == 1

    # Start a new episode before the current one has finished; nothing from
    # the abandoned episode may carry over.
    env.reset(difficulty="medium")
    assert env.state().step_count == 0
    assert env.state().total_reward == 0.0
    assert env.state().previous_actions == []
def test_step_identify_error(env):
    """identify_error earns a positive reward and does not end the episode."""
    # Use Round 1 task which has no DB simulator target to hit
    env.reset(difficulty="easy", task_id="easy_001")
    action = Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload={
            "error_location": "SELECT clause",
            "error_type": "syntax",
            "explanation": "Missing commas",
        },
    )
    resp = env.step(action)
    assert resp.reward.score > 0
    assert not resp.done  # identify_error is not terminal
def test_step_null_action(env):
    """A None action must never crash.

    It scores 0.0 or more and leaves the episode running.
    """
    env.reset(difficulty="easy")
    resp = env.step(None)
    assert resp.reward.score >= 0.0
    assert not resp.done
def test_step_after_done(env):
    """Step after done must not crash.

    A submit_answer on easy_001 ends the episode. Any further step must still
    report done and tell the caller to call reset().
    """
    env.reset(difficulty="easy", task_id="easy_001")
    action = Action(
        action_type=ActionType.SUBMIT_ANSWER,
        payload={
            "fixed_query": "SELECT id, name, email FROM users WHERE active = 1",
            "explanation": "Fixed",
            "confidence": 0.9,
        },
    )
    env.step(action)
    assert env.state().done

    # Step again after done
    resp = env.step(action)
    assert resp.done
    assert "Call reset()" in resp.reward.feedback
def test_dense_reward(env):
    """Intermediate (non-terminal) actions must already earn reward.

    The reward signal must be dense: at least one of the diagnostic steps
    below has to score something other than 0.0. Credit must not arrive only
    at the end of the episode.
    """
    env.reset(difficulty="easy")
    actions = [
        Action(
            action_type=ActionType.IDENTIFY_ERROR,
            payload={"error_location": "SELECT", "error_type": "syntax"},
        ),
        Action(
            action_type=ActionType.EXPLAIN_ISSUE,
            payload={"explanation": "Missing commas between column names in SELECT"},
        ),
    ]
    scores = []
    for action in actions:
        resp = env.step(action)
        scores.append(resp.reward.score)
        if resp.done:
            break
    # Rewards must not all be zero
    assert any(score != 0.0 for score in scores)
def test_max_steps(env):
    """Episode must terminate at max_steps, i.e. within a bounded number of steps."""
    env.reset(difficulty="easy")
    action = Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload={"error_location": "x", "error_type": "syntax"},
    )
    # Number of steps to try before declaring failure. Presumably chosen to
    # exceed the environment's max_steps — TODO(review): confirm against config.
    step_budget = 55
    # any() short-circuits, so stepping stops at the first done response.
    assert any(env.step(action).done for _ in range(step_budget))
def test_hint_injected_in_context(env):
    """After request_hint, the returned observation carries the hint."""
    env.reset(difficulty="easy")
    hint_request = Action(
        action_type=ActionType.REQUEST_HINT,
        payload={"hint_type": "location"},
    )
    observation = env.step(hint_request).observation
    assert "last_hint" in observation.current_context
def test_state_reflects_latest_step(env):
    """state() must always reflect the latest step accurately.

    After one identify_error step, the state shows one step, is initialized,
    and records the action type in its history.
    """
    env.reset(difficulty="easy")
    action = Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload={"error_location": "SELECT", "error_type": "syntax"},
    )
    env.step(action)
    s = env.state()
    assert s.step_count == 1
    assert s.initialized
    assert "identify_error" in s.previous_actions