Spaces:
Running
Running
File size: 1,338 Bytes
30cf758 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | import asyncio
import unittest
from server.env import SQLDebugEnv
from server.models import SQLDebugAction, ActionType
class TestEnv(unittest.TestCase):
    """Smoke tests for ``SQLDebugEnv`` on the ``easy_syntax_fix`` task.

    The environment exposes coroutine methods (``reset``/``step``), so each
    test wraps its steps in a local coroutine and drives it to completion
    with ``asyncio.run``.
    """

    def test_reset_and_inspect_schema(self):
        """Reset, then inspect the schema; the episode must stay open."""

        async def scenario():
            sql_env = SQLDebugEnv(task_id="easy_syntax_fix")
            start_obs, _reset_info = await sql_env.reset()
            self.assertFalse(start_obs.is_done)

            inspect_action = SQLDebugAction(action_type=ActionType.INSPECT_SCHEMA)
            step_result = await sql_env.step(inspect_action)
            after_obs, gained, finished, _step_info = step_result

            self.assertFalse(finished)
            # Inspecting the schema should populate schema information.
            self.assertIsNotNone(after_obs.schema_info)
            self.assertGreaterEqual(gained, 0.0)

        asyncio.run(scenario())

    def test_submit_broken_query_does_not_finish(self):
        """Submitting the task's own broken query must not end the episode."""

        async def scenario():
            sql_env = SQLDebugEnv(task_id="easy_syntax_fix")
            _start_obs, _ = await sql_env.reset()

            submit_action = SQLDebugAction(
                action_type=ActionType.SUBMIT_QUERY,
                query=sql_env.task.broken_query,
            )
            step_result = await sql_env.step(submit_action)
            after_obs, gained, finished, _ = step_result

            self.assertFalse(finished)
            # The reward for this submission must fall within [-1.0, 0.2].
            self.assertLessEqual(gained, 0.2)
            self.assertGreaterEqual(gained, -1.0)
            # The observation should echo back the query that was submitted.
            self.assertEqual(after_obs.current_query, sql_env.task.broken_query)

        asyncio.run(scenario())
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main(module="__main__")
|