File size: 4,546 Bytes
3c1b0c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cb206e
3c1b0c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8778707
 
3c1b0c7
 
 
 
 
8778707
3c1b0c7
 
eca453d
3c1b0c7
 
eca453d
3c1b0c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cb206e
3c1b0c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import pytest
from env.environment import SQLDebuggerEnvironment
from env.models import Action, ActionType, DifficultyLevel


@pytest.fixture
def env():
    """Provide a newly constructed SQLDebuggerEnvironment to each test.

    The fixture does not call ``reset()``; tests that need an active episode
    must reset it themselves.
    """
    return SQLDebuggerEnvironment()


def test_state_before_reset(env):
    """state() before reset must not crash — returns default state.

    A never-reset environment should report itself as uninitialised with no
    steps taken.
    """
    state = env.state()
    # Truthiness check instead of `== False` (PEP 8 / ruff E712).
    assert not state.initialized
    assert state.step_count == 0


def test_reset_easy(env):
    """An easy reset starts a fresh easy task and hides ``fixed_query``."""
    observation = env.reset(difficulty="easy")

    assert observation.task_id.startswith("easy")
    assert observation.step_count == 0
    assert observation.difficulty == DifficultyLevel.EASY

    context = observation.current_context
    # The expected solution must never be exposed to the agent.
    assert "fixed_query" not in context
    # The agent needs something to work on: a buggy query or slow queries.
    assert any(key in context for key in ("buggy_query", "slow_queries"))

def test_reset_medium(env):
    """A medium reset selects a task whose id is prefixed with 'medium'."""
    task_id = env.reset(difficulty="medium").task_id
    assert task_id.startswith("medium")


def test_reset_hard(env):
    """A hard reset selects a task whose id is prefixed with 'hard'."""
    task_id = env.reset(difficulty="hard").task_id
    assert task_id.startswith("hard")


def test_reset_clears_state(env):
    """A reset issued mid-episode must wipe per-episode state (no leakage)."""
    env.reset(difficulty="easy")
    env.step(Action(action_type=ActionType.IDENTIFY_ERROR,
                    payload={"error_location": "SELECT", "error_type": "syntax"}))
    assert env.state().step_count == 1

    # Start a new episode before the current one has finished.
    env.reset(difficulty="medium")

    # Each counter/accumulator must be back at its initial value.
    expected_fresh = {
        "step_count": 0,
        "total_reward": 0.0,
        "previous_actions": [],
    }
    for field_name, initial in expected_fresh.items():
        assert getattr(env.state(), field_name) == initial, field_name


def test_step_identify_error(env):
    """A correct identify_error step earns positive reward and is non-terminal."""
    # Pin to the Round 1 task easy_001, which (per the original note) has no
    # DB-simulator target to hit.
    env.reset(difficulty="easy", task_id="easy_001")
    action = Action(action_type=ActionType.IDENTIFY_ERROR,
                    payload={"error_location": "SELECT clause", "error_type": "syntax",
                             "explanation": "Missing commas"})
    resp = env.step(action)
    assert resp.reward.score > 0
    # identify_error is diagnostic only; it must not end the episode.
    assert not resp.done

def test_step_null_action(env):
    """A None action must not crash: non-negative reward, episode continues."""
    env.reset(difficulty="easy")
    resp = env.step(None)
    assert resp.reward.score >= 0.0
    # Truthiness check instead of `== False` (PEP 8 / ruff E712).
    assert not resp.done


def test_step_after_done(env):
    """Step after done must not crash.

    Once the episode has ended, further steps must keep reporting ``done`` and
    tell the caller to reset.
    """
    env.reset(difficulty="easy", task_id="easy_001")
    submit = Action(action_type=ActionType.SUBMIT_ANSWER,
                    payload={"fixed_query": "SELECT id, name, email FROM users WHERE active = 1",
                             "explanation": "Fixed", "confidence": 0.9})
    env.step(submit)
    # submit_answer is expected to be terminal.
    assert env.state().done

    # Step again after done: must be rejected gracefully, not raise.
    resp = env.step(submit)
    assert resp.done
    assert "Call reset()" in resp.reward.feedback


def test_dense_reward(env):
    """Intermediate (non-submit) steps must already earn non-zero reward.

    Guards against a sparse reward scheme that pays out only at episode end:
    at least one of the diagnostic steps below must score something.
    """
    # Pin the task: the payloads below describe a missing-comma bug in the
    # SELECT list, presumably easy_001's bug (the same explanation is used for
    # easy_001 in test_step_identify_error). A randomly chosen easy task could
    # legitimately score zero on both steps and make this test flaky.
    env.reset(difficulty="easy", task_id="easy_001")
    actions = [
        Action(action_type=ActionType.IDENTIFY_ERROR,
               payload={"error_location": "SELECT", "error_type": "syntax"}),
        Action(action_type=ActionType.EXPLAIN_ISSUE,
               payload={"explanation": "Missing commas between column names in SELECT"}),
    ]
    scores = []
    for action in actions:
        response = env.step(action)
        scores.append(response.reward.score)
        if response.done:
            break

    # Rewards must not all be zero.
    assert any(score != 0.0 for score in scores), f"all step rewards were zero: {scores}"


def test_max_steps(env):
    """Episode must terminate at max_steps even if the agent never submits."""
    env.reset(difficulty="easy")
    action = Action(action_type=ActionType.IDENTIFY_ERROR,
                    payload={"error_location": "x", "error_type": "syntax"})
    # Upper bound on steps to attempt. Assumed to exceed the environment's
    # max_steps, which is not visible here — TODO confirm against env config.
    step_budget = 55
    # any() short-circuits on the first terminal step, like the original break.
    done = any(env.step(action).done for _ in range(step_budget))
    assert done, f"episode did not terminate within {step_budget} steps"


def test_hint_injected_in_context(env):
    """After request_hint, the returned observation's context carries the hint."""
    env.reset(difficulty="easy")
    hint_request = Action(
        action_type=ActionType.REQUEST_HINT,
        payload={"hint_type": "location"},
    )
    context = env.step(hint_request).observation.current_context
    assert "last_hint" in context


def test_state_reflects_latest_step(env):
    """state() must always reflect the latest step accurately."""
    env.reset(difficulty="easy")
    action = Action(action_type=ActionType.IDENTIFY_ERROR,
                    payload={"error_location": "SELECT", "error_type": "syntax"})
    env.step(action)
    state = env.state()
    assert state.step_count == 1
    # Truthiness check instead of `== True` (PEP 8 / ruff E712).
    assert state.initialized
    # The action history records action types by their string value.
    assert "identify_error" in state.previous_actions