File size: 1,338 Bytes
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import asyncio
import unittest

from server.env import SQLDebugEnv
from server.models import SQLDebugAction, ActionType


class TestEnv(unittest.IsolatedAsyncioTestCase):
    """Smoke tests for ``SQLDebugEnv`` driven through its async reset/step API.

    ``IsolatedAsyncioTestCase`` runs every ``async def test_*`` method on a
    fresh event loop.  This replaces the hand-rolled pattern of defining a
    nested coroutine in each test and wrapping it in ``asyncio.run(...)``.
    """

    # Task exercised by every test in this class.
    TASK_ID = "easy_syntax_fix"

    async def _reset_env(self):
        """Create an env for ``TASK_ID``, reset it, and return ``(env, obs)``.

        The ``info`` value returned by ``reset()`` is discarded because no
        test inspects it.
        """
        env = SQLDebugEnv(task_id=self.TASK_ID)
        obs, _info = await env.reset()
        return env, obs

    async def test_reset_and_inspect_schema(self):
        """A fresh episode is not done; INSPECT_SCHEMA yields schema info."""
        env, obs = await self._reset_env()
        self.assertFalse(obs.is_done)

        action = SQLDebugAction(action_type=ActionType.INSPECT_SCHEMA)
        obs2, reward, done, _info = await env.step(action)
        self.assertFalse(done)
        self.assertIsNotNone(obs2.schema_info)
        self.assertGreaterEqual(reward, 0.0)

    async def test_submit_broken_query_does_not_finish(self):
        """Re-submitting the task's broken query keeps the episode running.

        The test also checks three things: the reward is in [-1.0, 0.2], and
        the observation echoes the submitted query back as ``current_query``.
        """
        env, _obs = await self._reset_env()

        action = SQLDebugAction(
            action_type=ActionType.SUBMIT_QUERY,
            query=env.task.broken_query,
        )
        obs2, reward, done, _info = await env.step(action)

        self.assertFalse(done)
        self.assertLessEqual(reward, 0.2)
        self.assertGreaterEqual(reward, -1.0)
        self.assertEqual(obs2.current_query, env.task.broken_query)


if __name__ == "__main__":
    # Allow running this module directly (e.g. ``python test_env.py``) in
    # addition to discovery by a test runner; verbosity=1 is unittest's default.
    unittest.main(verbosity=1)