File size: 4,366 Bytes
72805b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac49ad8
72805b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Tests for SQL Arena environment."""

import pytest
from src.sql_arena.environment import SQLArenaEnvironment
from src.sql_arena.models import SQLArenaAction
from src.sql_arena.tasks import list_tasks, TASK_BY_ID


class TestEnvironmentBasics:

    def setup_method(self):
        self.env = SQLArenaEnvironment()

    def teardown_method(self):
        self.env.close()

    def test_reset_returns_observation(self):
        result = self.env.reset(difficulty="basic_select", task_id="easy_001")
        assert result.observation is not None
        assert result.reward == 0.0
        assert result.done is False

    def test_step_with_correct_query(self):
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        task = self.env.current_task
        action = SQLArenaAction(sql_query=task.expected_sql)
        result = self.env.step(action)
        assert result.reward > 0.0
        assert result.info.get("score", 0) >= 0.8

    def test_step_with_invalid_query(self):
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        action = SQLArenaAction(sql_query="INVALID SQL QUERY")
        result = self.env.step(action)
        assert result.reward == 0.01  # Clamped to strictly > 0
        assert result.observation.error_message is not None

    def test_state_tracking(self):
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        state = self.env.state()
        assert state.current_step == 0

        self.env.step(SQLArenaAction(sql_query="SELECT 1"))
        state = self.env.state()
        assert state.current_step == 1

    def test_episode_terminates(self):
        self.env.reset(difficulty="basic_select", task_id="easy_001")
        task = self.env.current_task
        for _ in range(task.max_steps + 1):
            if self.env.state().done:
                break
            self.env.step(SQLArenaAction(sql_query="SELECT 1"))
        assert self.env.state().done is True


class TestAllDifficulties:

    def setup_method(self):
        self.env = SQLArenaEnvironment()

    def teardown_method(self):
        self.env.close()

    def test_easy(self):
        result = self.env.reset(difficulty="basic_select")
        assert result.observation.difficulty == "basic_select"

    def test_medium(self):
        result = self.env.reset(difficulty="join_aggregate")
        assert result.observation.difficulty == "join_aggregate"

    def test_hard(self):
        result = self.env.reset(difficulty="complex_analysis")
        assert result.observation.difficulty == "complex_analysis"


class TestGrading:

    def setup_method(self):
        self.env = SQLArenaEnvironment()

    def teardown_method(self):
        self.env.close()

    def test_scores_in_range(self):
        for task_id, task in TASK_BY_ID.items():
            self.env.reset(difficulty=task.difficulty, task_id=task_id)
            action = SQLArenaAction(sql_query=task.expected_sql)
            result = self.env.step(action)
            assert 0.0 <= result.reward <= 1.0
            assert 0.0 <= result.info.get("score", 0) <= 1.0
            self.env.reset(difficulty=task.difficulty, task_id=task_id)

    def test_varying_scores(self):
        scores = set()
        queries = [
            "SELECT name, salary FROM employees WHERE is_active = 1 AND salary > 80000 ORDER BY salary DESC",
            "SELECT * FROM employees",
            "INVALID",
            "SELECT name FROM employees",
        ]
        for q in queries:
            self.env.reset(difficulty="basic_select", task_id="easy_001")
            result = self.env.step(SQLArenaAction(sql_query=q))
            scores.add(round(result.info.get("score", 0), 2))
        assert len(scores) > 1, "Grader always returns the same score!"


class TestTaskRegistry:

    def test_list_tasks(self):
        tasks = list_tasks()
        assert "basic_select" in tasks
        assert "join_aggregate" in tasks
        assert "complex_analysis" in tasks

    def test_minimum_3_tasks(self):
        tasks = list_tasks()
        for difficulty, task_ids in tasks.items():
            assert len(task_ids) >= 3, f"{difficulty} has fewer than 3 tasks"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])