Spaces:
Sleeping
Sleeping
| import pytest | |
| from env.environment import SQLDebuggerEnvironment | |
| from env.models import Action, ActionType, DifficultyLevel | |
| def env(): | |
| e = SQLDebuggerEnvironment() | |
| return e | |
| def test_state_before_reset(env): | |
| """state() before reset must not crash — returns default state.""" | |
| s = env.state() | |
| assert s.initialized == False | |
| assert s.step_count == 0 | |
| def test_reset_easy(env): | |
| obs = env.reset(difficulty="easy") | |
| assert obs.task_id.startswith("easy") | |
| assert obs.step_count == 0 | |
| assert obs.difficulty == DifficultyLevel.EASY | |
| assert "fixed_query" not in obs.current_context | |
| assert "buggy_query" in obs.current_context or "slow_queries" in obs.current_context | |
| def test_reset_medium(env): | |
| obs = env.reset(difficulty="medium") | |
| assert obs.task_id.startswith("medium") | |
| def test_reset_hard(env): | |
| obs = env.reset(difficulty="hard") | |
| assert obs.task_id.startswith("hard") | |
| def test_reset_clears_state(env): | |
| """Reset mid-episode must clear all state — no leakage.""" | |
| env.reset(difficulty="easy") | |
| action = Action(action_type=ActionType.IDENTIFY_ERROR, | |
| payload={"error_location": "SELECT", "error_type": "syntax"}) | |
| env.step(action) | |
| assert env.state().step_count == 1 | |
| # Reset mid-episode | |
| env.reset(difficulty="medium") | |
| assert env.state().step_count == 0 | |
| assert env.state().total_reward == 0.0 | |
| assert env.state().previous_actions == [] | |
| def test_step_identify_error(env): | |
| # Use Round 1 task which has no DB simulator target to hit | |
| env.reset(difficulty="easy", task_id="easy_001") | |
| action = Action(action_type=ActionType.IDENTIFY_ERROR, | |
| payload={"error_location": "SELECT clause", "error_type": "syntax", | |
| "explanation": "Missing commas"}) | |
| resp = env.step(action) | |
| assert resp.reward.score > 0 | |
| assert resp.done == False # identify_error is not terminal | |
| def test_step_null_action(env): | |
| """Null action must return 0.0 or greater, never crash.""" | |
| env.reset(difficulty="easy") | |
| resp = env.step(None) | |
| assert resp.reward.score >= 0.0 | |
| assert resp.done == False | |
| def test_step_after_done(env): | |
| """Step after done must not crash.""" | |
| env.reset(difficulty="easy", task_id="easy_001") | |
| action = Action(action_type=ActionType.SUBMIT_ANSWER, | |
| payload={"fixed_query": "SELECT id, name, email FROM users WHERE active = 1", | |
| "explanation": "Fixed", "confidence": 0.9}) | |
| env.step(action) | |
| assert env.state().done == True | |
| # Step again after done | |
| resp = env.step(action) | |
| assert resp.done == True | |
| assert "Call reset()" in resp.reward.feedback | |
| def test_dense_reward(env): | |
| """Reward must vary at each step — not only at end.""" | |
| env.reset(difficulty="easy") | |
| rewards = [] | |
| actions = [ | |
| Action(action_type=ActionType.IDENTIFY_ERROR, | |
| payload={"error_location": "SELECT", "error_type": "syntax"}), | |
| Action(action_type=ActionType.EXPLAIN_ISSUE, | |
| payload={"explanation": "Missing commas between column names in SELECT"}), | |
| ] | |
| for a in actions: | |
| r = env.step(a) | |
| rewards.append(r.reward.score) | |
| if r.done: | |
| break | |
| # Rewards must not all be zero | |
| assert any(r != 0.0 for r in rewards) | |
| def test_max_steps(env): | |
| """Episode must terminate at max_steps.""" | |
| env.reset(difficulty="easy") | |
| action = Action(action_type=ActionType.IDENTIFY_ERROR, | |
| payload={"error_location": "x", "error_type": "syntax"}) | |
| done = False | |
| for _ in range(55): | |
| resp = env.step(action) | |
| if resp.done: | |
| done = True | |
| break | |
| assert done == True | |
| def test_hint_injected_in_context(env): | |
| """Hint must appear in next observation after request_hint.""" | |
| env.reset(difficulty="easy") | |
| action = Action(action_type=ActionType.REQUEST_HINT, | |
| payload={"hint_type": "location"}) | |
| resp = env.step(action) | |
| assert "last_hint" in resp.observation.current_context | |
| def test_state_reflects_latest_step(env): | |
| """state() must always reflect the latest step accurately.""" | |
| env.reset(difficulty="easy") | |
| action = Action(action_type=ActionType.IDENTIFY_ERROR, | |
| payload={"error_location": "SELECT", "error_type": "syntax"}) | |
| env.step(action) | |
| s = env.state() | |
| assert s.step_count == 1 | |
| assert s.initialized == True | |
| assert "identify_error" in s.previous_actions |