snigenigmatic committed on
Commit
2e6e0b2
·
verified ·
1 Parent(s): 0683cf4

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. server/sql_environment.py +35 -9
server/sql_environment.py CHANGED
@@ -1,8 +1,8 @@
1
  import sqlite3
2
  import random
3
- from typing import Tuple
4
 
5
- from openenv.core.env_server.types import EnvBase
6
  from models import SQLAction, SQLObservation, SQLState
7
  from server.challenges import CHALLENGES
8
 
@@ -47,9 +47,22 @@ def _results_match(schema_sql: str, query_a: str, query_b: str) -> bool:
47
  return False
48
 
49
 
50
- class SQLTutorEnvironment(EnvBase[SQLAction, SQLObservation, SQLState]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- def reset(self) -> Tuple[SQLObservation, SQLState]:
53
  challenge = random.choice(CHALLENGES)
54
 
55
  state = SQLState(
@@ -65,6 +78,8 @@ class SQLTutorEnvironment(EnvBase[SQLAction, SQLObservation, SQLState]):
65
  hints_used=0,
66
  is_resolved=False,
67
  cumulative_reward=0.0,
 
 
68
  )
69
  self._state = state
70
 
@@ -81,12 +96,20 @@ class SQLTutorEnvironment(EnvBase[SQLAction, SQLObservation, SQLState]):
81
  steps_taken=0,
82
  max_steps=state.max_steps,
83
  hints_used=0,
 
 
84
  )
85
- return observation, state
86
-
87
- def step(self, action: SQLAction) -> Tuple[SQLObservation, float, bool, SQLState]:
 
 
 
 
 
88
  state = self._state
89
  state.steps_taken += 1
 
90
  reward = 0.0
91
  done = False
92
  hint = None
@@ -147,9 +170,12 @@ class SQLTutorEnvironment(EnvBase[SQLAction, SQLObservation, SQLState]):
147
  steps_taken=state.steps_taken,
148
  max_steps=state.max_steps,
149
  hints_used=state.hints_used,
 
 
150
  )
151
 
152
- return observation, reward, done, state
153
 
154
- def get_state(self) -> SQLState:
 
155
  return self._state
 
1
  import sqlite3
2
  import random
3
+ from typing import Any, Optional, Tuple
4
 
5
+ from openenv.core.env_server.interfaces import Environment
6
  from models import SQLAction, SQLObservation, SQLState
7
  from server.challenges import CHALLENGES
8
 
 
47
  return False
48
 
49
 
50
+ class SQLTutorEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
51
+ SUPPORTS_CONCURRENT_SESSIONS = True
52
+
53
+ def __init__(self):
54
+ super().__init__()
55
+ self._state = SQLState()
56
+
57
+ def reset(
58
+ self,
59
+ seed: Optional[int] = None,
60
+ episode_id: Optional[str] = None,
61
+ **kwargs: Any,
62
+ ) -> SQLObservation:
63
+ if seed is not None:
64
+ random.seed(seed)
65
 
 
66
  challenge = random.choice(CHALLENGES)
67
 
68
  state = SQLState(
 
78
  hints_used=0,
79
  is_resolved=False,
80
  cumulative_reward=0.0,
81
+ episode_id=episode_id,
82
+ step_count=0,
83
  )
84
  self._state = state
85
 
 
96
  steps_taken=0,
97
  max_steps=state.max_steps,
98
  hints_used=0,
99
+ done=False,
100
+ reward=None,
101
  )
102
+ return observation
103
+
104
+ def step(
105
+ self,
106
+ action: SQLAction,
107
+ timeout_s: Optional[float] = None,
108
+ **kwargs: Any,
109
+ ) -> SQLObservation:
110
  state = self._state
111
  state.steps_taken += 1
112
+ state.step_count += 1
113
  reward = 0.0
114
  done = False
115
  hint = None
 
170
  steps_taken=state.steps_taken,
171
  max_steps=state.max_steps,
172
  hints_used=state.hints_used,
173
+ done=done,
174
+ reward=reward,
175
  )
176
 
177
+ return observation
178
 
179
+ @property
180
+ def state(self) -> SQLState:
181
  return self._state