jaivardhan2409 commited on
Commit
126939a
·
verified ·
1 Parent(s): aeea577

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. baseline.py +37 -19
  2. env/environment.py +72 -47
  3. env/models.py +25 -17
  4. inference.py +89 -0
  5. models.py +25 -17
  6. server/app.py +15 -2
baseline.py CHANGED
@@ -1,32 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from openai import OpenAI
3
  from env.environment import SQLEnv
4
  from env.models import Action
5
 
 
6
  def run_task(env: SQLEnv, task_id: int) -> float:
7
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
8
- obs = env.reset(task_id)
9
-
10
  messages = [
11
- {"role": "system", "content": "You are an expert SQL DBA. You rewrite SQL queries to be correct, optimized, and performant."}
 
 
 
 
 
 
12
  ]
13
-
14
- prompt = f"""
15
- Task # {obs.task_id}
16
  Original Query: {obs.query}
17
  Database Schema Context: {obs.schema_context}
18
  Hint: {obs.hint}
19
 
20
- Please provide the optimized query. Output ONLY the raw SQL query, no markdown formatting, no explanation.
21
- """
22
-
23
  messages.append({"role": "user", "content": prompt.strip()})
24
-
25
  try:
26
  response = client.chat.completions.create(
27
  model="gpt-3.5-turbo",
28
  messages=messages,
29
- temperature=0.0
30
  )
31
  rewritten_query = response.choices[0].message.content.strip()
32
  if rewritten_query.startswith("```sql"):
@@ -37,30 +53,32 @@ Please provide the optimized query. Output ONLY the raw SQL query, no markdown f
37
  except Exception as e:
38
  print(f"Error calling OpenAI API: {e}")
39
  rewritten_query = obs.query
40
-
41
  action = Action(
42
  rewritten_query=rewritten_query,
43
  explanation="Baseline inference using LLM",
44
- is_done=True
45
  )
46
-
47
- _, reward, done, info = env.step(action)
48
- return env.final_grader_score
 
49
 
50
  def run_all_tasks():
51
  if not os.environ.get("OPENAI_API_KEY"):
52
  raise ValueError("OPENAI_API_KEY environment variable is required.")
53
-
54
  env = SQLEnv()
55
  scores = {}
56
  for task_id in [1, 2, 3]:
57
  print(f"Running baseline for Task {task_id}...")
58
  score = run_task(env, task_id)
59
  scores[task_id] = score
60
- print(f"Task {task_id} Grader Score: {score}")
61
-
62
  return scores
63
 
 
64
  if __name__ == "__main__":
65
  try:
66
  scores = run_all_tasks()
 
1
+ """
2
+ Baseline inference script for the SQL Query Optimizer OpenEnv.
3
+
4
+ Uses the OpenAI API client to run a model against the environment
5
+ and produce reproducible baseline scores on all 3 tasks.
6
+
7
+ Usage:
8
+ export OPENAI_API_KEY=sk-...
9
+ python baseline.py
10
+ """
11
+
12
  import os
13
  from openai import OpenAI
14
  from env.environment import SQLEnv
15
  from env.models import Action
16
 
17
+
18
  def run_task(env: SQLEnv, task_id: int) -> float:
19
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
20
+ obs = env.reset(task_id=task_id)
21
+
22
  messages = [
23
+ {
24
+ "role": "system",
25
+ "content": (
26
+ "You are an expert SQL DBA. You rewrite SQL queries "
27
+ "to be correct, optimized, and performant."
28
+ ),
29
+ }
30
  ]
31
+
32
+ prompt = f"""Task #{obs.task_id}
 
33
  Original Query: {obs.query}
34
  Database Schema Context: {obs.schema_context}
35
  Hint: {obs.hint}
36
 
37
+ Please provide the optimized query. Output ONLY the raw SQL query, no markdown formatting, no explanation."""
38
+
 
39
  messages.append({"role": "user", "content": prompt.strip()})
40
+
41
  try:
42
  response = client.chat.completions.create(
43
  model="gpt-3.5-turbo",
44
  messages=messages,
45
+ temperature=0.0,
46
  )
47
  rewritten_query = response.choices[0].message.content.strip()
48
  if rewritten_query.startswith("```sql"):
 
53
  except Exception as e:
54
  print(f"Error calling OpenAI API: {e}")
55
  rewritten_query = obs.query
56
+
57
  action = Action(
58
  rewritten_query=rewritten_query,
59
  explanation="Baseline inference using LLM",
60
+ is_done=True,
61
  )
62
+
63
+ result_obs = env.step(action)
64
+ return result_obs.reward
65
+
66
 
67
  def run_all_tasks():
68
  if not os.environ.get("OPENAI_API_KEY"):
69
  raise ValueError("OPENAI_API_KEY environment variable is required.")
70
+
71
  env = SQLEnv()
72
  scores = {}
73
  for task_id in [1, 2, 3]:
74
  print(f"Running baseline for Task {task_id}...")
75
  score = run_task(env, task_id)
76
  scores[task_id] = score
77
+ print(f"Task {task_id} Score: {score}")
78
+
79
  return scores
80
 
81
+
82
  if __name__ == "__main__":
83
  try:
84
  scores = run_all_tasks()
env/environment.py CHANGED
@@ -1,9 +1,19 @@
1
- from typing import Tuple, Dict, Any, List
 
 
 
 
 
2
  from .models import Observation, Action, Reward
3
  from .tasks import TASKS, grade_action, get_task
4
  from .reward import compute_reward
5
 
6
- class SQLEnv:
 
 
 
 
 
7
  def __init__(self):
8
  self.current_task_id = None
9
  self.task = None
@@ -13,12 +23,19 @@ class SQLEnv:
13
  self.cumulative_score = 0.0
14
  self.previous_grader_score = 0.0
15
  self.final_grader_score = 0.0
16
-
17
- def reset(self, task_id: int) -> Observation:
 
 
 
 
 
 
 
18
  task = get_task(task_id)
19
  if not task:
20
  raise ValueError(f"Task {task_id} not found.")
21
-
22
  self.current_task_id = task_id
23
  self.task = task
24
  self.step_number = 1
@@ -27,80 +44,88 @@ class SQLEnv:
27
  self.cumulative_score = 0.0
28
  self.previous_grader_score = 0.0
29
  self.final_grader_score = 0.0
30
-
 
 
 
 
31
  obs = Observation(
32
  task_id=self.current_task_id,
33
  query=self.task["initial_query"],
34
  schema_context=self.task["schema_context"],
35
  hint=self.task["hint"],
36
  step_number=self.step_number,
37
- max_steps=self.max_steps
 
 
38
  )
39
  self.history.append({"step": 0, "type": "reset", "observation": obs.model_dump()})
40
  return obs
41
-
42
- def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
 
 
 
 
 
43
  if not self.task:
44
  raise RuntimeError("Environment not initialized. Call reset() first.")
45
-
46
- grader_score, breakdown, feedback = grade_action(self.current_task_id, action.rewritten_query)
 
 
47
  action_valid = len(action.rewritten_query.strip()) > 0
48
-
49
  done = action.is_done or self.step_number >= self.max_steps
50
-
51
  step_reward = compute_reward(
52
  grader_score=grader_score,
53
  previous_score=self.previous_grader_score,
54
  step_number=self.step_number,
55
  max_steps=self.max_steps,
56
  is_done=done,
57
- action_valid=action_valid
58
  )
59
-
60
  self.cumulative_score += step_reward
61
  self.previous_grader_score = grader_score
62
-
63
- reward = Reward(
64
- score=step_reward,
65
- breakdown=breakdown,
66
- feedback=feedback
67
- )
68
-
 
 
 
 
 
 
69
  obs = Observation(
70
  task_id=self.current_task_id,
71
  query=action.rewritten_query,
72
  schema_context=self.task["schema_context"],
73
- hint=self.task["hint"],
74
  step_number=self.step_number + 1,
75
- max_steps=self.max_steps
 
 
 
76
  )
77
-
78
- info = {
79
- "cumulative_score": self.cumulative_score,
80
- "grader_score": grader_score
81
- }
82
-
83
- if done:
84
- self.final_grader_score = grader_score
85
-
86
  self.history.append({
87
  "step": self.step_number,
88
  "type": "step",
89
  "action": action.model_dump(),
90
- "reward": reward.model_dump(),
91
  "done": done,
92
- "info": info
93
  })
94
-
95
  self.step_number += 1
96
- return obs, reward, done, info
97
-
98
- def state(self) -> Dict[str, Any]:
99
- return {
100
- "current_task_id": self.current_task_id,
101
- "step_number": self.step_number,
102
- "max_steps": self.max_steps,
103
- "cumulative_score": self.cumulative_score,
104
- "final_grader_score": self.final_grader_score,
105
- "history": self.history
106
- }
 
1
+ from typing import Optional, Dict, Any
2
+ from uuid import uuid4
3
+
4
+ from openenv.core.env_server.interfaces import Environment
5
+ from openenv.core.env_server.types import State
6
+
7
  from .models import Observation, Action, Reward
8
  from .tasks import TASKS, grade_action, get_task
9
  from .reward import compute_reward
10
 
11
+
12
+ class SQLEnv(Environment):
13
+ """SQL Query Optimizer Environment following the OpenEnv interface."""
14
+
15
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
16
+
17
  def __init__(self):
18
  self.current_task_id = None
19
  self.task = None
 
23
  self.cumulative_score = 0.0
24
  self.previous_grader_score = 0.0
25
  self.final_grader_score = 0.0
26
+ self._state = State(episode_id=str(uuid4()), step_count=0)
27
+
28
+ def reset(
29
+ self,
30
+ seed: Optional[int] = None,
31
+ episode_id: Optional[str] = None,
32
+ task_id: int = 1,
33
+ **kwargs: Any,
34
+ ) -> Observation:
35
  task = get_task(task_id)
36
  if not task:
37
  raise ValueError(f"Task {task_id} not found.")
38
+
39
  self.current_task_id = task_id
40
  self.task = task
41
  self.step_number = 1
 
44
  self.cumulative_score = 0.0
45
  self.previous_grader_score = 0.0
46
  self.final_grader_score = 0.0
47
+ self._state = State(
48
+ episode_id=episode_id or str(uuid4()),
49
+ step_count=0,
50
+ )
51
+
52
  obs = Observation(
53
  task_id=self.current_task_id,
54
  query=self.task["initial_query"],
55
  schema_context=self.task["schema_context"],
56
  hint=self.task["hint"],
57
  step_number=self.step_number,
58
+ max_steps=self.max_steps,
59
+ reward=0.0,
60
+ done=False,
61
  )
62
  self.history.append({"step": 0, "type": "reset", "observation": obs.model_dump()})
63
  return obs
64
+
65
+ def step(
66
+ self,
67
+ action: Action,
68
+ timeout_s: Optional[float] = None,
69
+ **kwargs: Any,
70
+ ) -> Observation:
71
  if not self.task:
72
  raise RuntimeError("Environment not initialized. Call reset() first.")
73
+
74
+ grader_score, breakdown, feedback = grade_action(
75
+ self.current_task_id, action.rewritten_query
76
+ )
77
  action_valid = len(action.rewritten_query.strip()) > 0
78
+
79
  done = action.is_done or self.step_number >= self.max_steps
80
+
81
  step_reward = compute_reward(
82
  grader_score=grader_score,
83
  previous_score=self.previous_grader_score,
84
  step_number=self.step_number,
85
  max_steps=self.max_steps,
86
  is_done=done,
87
+ action_valid=action_valid,
88
  )
89
+
90
  self.cumulative_score += step_reward
91
  self.previous_grader_score = grader_score
92
+
93
+ info = {
94
+ "cumulative_score": self.cumulative_score,
95
+ "grader_score": grader_score,
96
+ "breakdown": breakdown,
97
+ "feedback": feedback,
98
+ }
99
+
100
+ if done:
101
+ self.final_grader_score = grader_score
102
+
103
+ self._state.step_count += 1
104
+
105
  obs = Observation(
106
  task_id=self.current_task_id,
107
  query=action.rewritten_query,
108
  schema_context=self.task["schema_context"],
109
+ hint=self.task["hint"],
110
  step_number=self.step_number + 1,
111
+ max_steps=self.max_steps,
112
+ reward=step_reward,
113
+ done=done,
114
+ metadata=info,
115
  )
116
+
 
 
 
 
 
 
 
 
117
  self.history.append({
118
  "step": self.step_number,
119
  "type": "step",
120
  "action": action.model_dump(),
121
+ "reward": step_reward,
122
  "done": done,
123
+ "info": info,
124
  })
125
+
126
  self.step_number += 1
127
+ return obs
128
+
129
+ @property
130
+ def state(self) -> State:
131
+ return self._state
 
 
 
 
 
 
env/models.py CHANGED
@@ -1,20 +1,28 @@
1
- from typing import Optional, Dict
2
- from pydantic import BaseModel, Field
 
3
 
4
- class Observation(BaseModel):
5
- task_id: int = Field(description="The ID of the task to perform.")
6
- query: str = Field(description="The SQL query to review and optimize.")
7
- schema_context: str = Field(description="The database schema context for the query, such as CREATE TABLE statements.")
8
- hint: Optional[str] = Field(default=None, description="An optional natural-language hint or description of the problem.")
9
- step_number: int = Field(description="The current step number in the episode (1-indexed).")
10
- max_steps: int = Field(description="The maximum allowed steps for this task.")
11
 
12
- class Action(BaseModel):
13
- rewritten_query: str = Field(description="The rewritten, optimized SQL query.")
14
- explanation: str = Field(description="A brief explanation of the changes made and why they improve the query.")
15
- is_done: bool = Field(description="Set to true if you are finished and want to submit the query for final scoring.")
 
 
 
16
 
17
- class Reward(BaseModel):
18
- score: float = Field(description="The overall score for the episode (0.0 to 1.0).")
19
- breakdown: Dict[str, float] = Field(default_factory=dict, description="A breakdown of the score by sub-criteria.")
20
- feedback: str = Field(description="Specific feedback on the rewritten query or action taken.")
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Any
2
+ from pydantic import Field
3
+ from openenv.core.env_server.types import Action as BaseAction, Observation as BaseObservation
4
 
 
 
 
 
 
 
 
5
 
6
+ class Observation(BaseObservation):
7
+ task_id: int = Field(default=0, description="The ID of the task to perform.")
8
+ query: str = Field(default="", description="The SQL query to review and optimize.")
9
+ schema_context: str = Field(default="", description="The database schema context.")
10
+ hint: Optional[str] = Field(default=None, description="An optional natural-language hint.")
11
+ step_number: int = Field(default=0, description="The current step number in the episode.")
12
+ max_steps: int = Field(default=0, description="The maximum allowed steps for this task.")
13
 
14
+
15
+ class Action(BaseAction):
16
+ rewritten_query: str = Field(default="", description="The rewritten, optimized SQL query.")
17
+ explanation: str = Field(default="", description="A brief explanation of the changes.")
18
+ is_done: bool = Field(default=False, description="Set to true to submit for final scoring.")
19
+
20
+
21
+ class Reward:
22
+ def __init__(self, score: float = 0.0, breakdown: Dict[str, float] = None, feedback: str = ""):
23
+ self.score = score
24
+ self.breakdown = breakdown or {}
25
+ self.feedback = feedback
26
+
27
+ def model_dump(self):
28
+ return {"score": self.score, "breakdown": self.breakdown, "feedback": self.feedback}
inference.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline inference script for the SQL Query Optimizer OpenEnv.
3
+
4
+ Uses the OpenAI API client to run a model against the environment
5
+ and produce reproducible baseline scores on all 3 tasks.
6
+
7
+ Usage:
8
+ export OPENAI_API_KEY=sk-...
9
+ python inference.py
10
+ """
11
+
12
+ import os
13
+ from openai import OpenAI
14
+ from env.environment import SQLEnv
15
+ from env.models import Action
16
+
17
+
18
+ def run_task(env: SQLEnv, task_id: int) -> float:
19
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
20
+ obs = env.reset(task_id=task_id)
21
+
22
+ messages = [
23
+ {
24
+ "role": "system",
25
+ "content": (
26
+ "You are an expert SQL DBA. You rewrite SQL queries "
27
+ "to be correct, optimized, and performant."
28
+ ),
29
+ }
30
+ ]
31
+
32
+ prompt = f"""Task #{obs.task_id}
33
+ Original Query: {obs.query}
34
+ Database Schema Context: {obs.schema_context}
35
+ Hint: {obs.hint}
36
+
37
+ Please provide the optimized query. Output ONLY the raw SQL query, no markdown formatting, no explanation."""
38
+
39
+ messages.append({"role": "user", "content": prompt.strip()})
40
+
41
+ try:
42
+ response = client.chat.completions.create(
43
+ model="gpt-3.5-turbo",
44
+ messages=messages,
45
+ temperature=0.0,
46
+ )
47
+ rewritten_query = response.choices[0].message.content.strip()
48
+ if rewritten_query.startswith("```sql"):
49
+ rewritten_query = rewritten_query[6:]
50
+ if rewritten_query.endswith("```"):
51
+ rewritten_query = rewritten_query[:-3]
52
+ rewritten_query = rewritten_query.strip()
53
+ except Exception as e:
54
+ print(f"Error calling OpenAI API: {e}")
55
+ rewritten_query = obs.query
56
+
57
+ action = Action(
58
+ rewritten_query=rewritten_query,
59
+ explanation="Baseline inference using LLM",
60
+ is_done=True,
61
+ )
62
+
63
+ result_obs = env.step(action)
64
+ return result_obs.reward
65
+
66
+
67
+ def run_all_tasks():
68
+ if not os.environ.get("OPENAI_API_KEY"):
69
+ raise ValueError("OPENAI_API_KEY environment variable is required.")
70
+
71
+ env = SQLEnv()
72
+ scores = {}
73
+ for task_id in [1, 2, 3]:
74
+ print(f"Running baseline for Task {task_id}...")
75
+ score = run_task(env, task_id)
76
+ scores[task_id] = score
77
+ print(f"Task {task_id} Score: {score}")
78
+
79
+ return scores
80
+
81
+
82
+ if __name__ == "__main__":
83
+ try:
84
+ scores = run_all_tasks()
85
+ print("\nBaseline Evaluation Results:")
86
+ for t, s in scores.items():
87
+ print(f"Task {t}: {s}/1.0")
88
+ except Exception as e:
89
+ print(f"Baseline Evaluation Failed: {e}")
models.py CHANGED
@@ -1,20 +1,28 @@
1
- from typing import Optional, Dict
2
- from pydantic import BaseModel, Field
 
3
 
4
- class Observation(BaseModel):
5
- task_id: int = Field(description="The ID of the task to perform.")
6
- query: str = Field(description="The SQL query to review and optimize.")
7
- schema_context: str = Field(description="The database schema context for the query, such as CREATE TABLE statements.")
8
- hint: Optional[str] = Field(default=None, description="An optional natural-language hint or description of the problem.")
9
- step_number: int = Field(description="The current step number in the episode (1-indexed).")
10
- max_steps: int = Field(description="The maximum allowed steps for this task.")
11
 
12
- class Action(BaseModel):
13
- rewritten_query: str = Field(description="The rewritten, optimized SQL query.")
14
- explanation: str = Field(description="A brief explanation of the changes made and why they improve the query.")
15
- is_done: bool = Field(description="Set to true if you are finished and want to submit the query for final scoring.")
 
 
 
16
 
17
- class Reward(BaseModel):
18
- score: float = Field(description="The overall score for the episode (0.0 to 1.0).")
19
- breakdown: Dict[str, float] = Field(default_factory=dict, description="A breakdown of the score by sub-criteria.")
20
- feedback: str = Field(description="Specific feedback on the rewritten query or action taken.")
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Any
2
+ from pydantic import Field
3
+ from openenv.core.env_server.types import Action as BaseAction, Observation as BaseObservation
4
 
 
 
 
 
 
 
 
5
 
6
+ class Observation(BaseObservation):
7
+ task_id: int = Field(default=0, description="The ID of the task to perform.")
8
+ query: str = Field(default="", description="The SQL query to review and optimize.")
9
+ schema_context: str = Field(default="", description="The database schema context.")
10
+ hint: Optional[str] = Field(default=None, description="An optional natural-language hint.")
11
+ step_number: int = Field(default=0, description="The current step number in the episode.")
12
+ max_steps: int = Field(default=0, description="The maximum allowed steps for this task.")
13
 
14
+
15
+ class Action(BaseAction):
16
+ rewritten_query: str = Field(default="", description="The rewritten, optimized SQL query.")
17
+ explanation: str = Field(default="", description="A brief explanation of the changes.")
18
+ is_done: bool = Field(default=False, description="Set to true to submit for final scoring.")
19
+
20
+
21
+ class Reward:
22
+ def __init__(self, score: float = 0.0, breakdown: Dict[str, float] = None, feedback: str = ""):
23
+ self.score = score
24
+ self.breakdown = breakdown or {}
25
+ self.feedback = feedback
26
+
27
+ def model_dump(self):
28
+ return {"score": self.score, "breakdown": self.breakdown, "feedback": self.feedback}
server/app.py CHANGED
@@ -12,26 +12,39 @@ app = create_app(
12
  env=SQLEnv,
13
  action_cls=Action,
14
  observation_cls=Observation,
15
- env_name="sql-query-optimizer"
16
  )
17
 
 
18
  @app.get("/tasks")
19
  async def get_tasks():
20
  action_schema = Action.model_json_schema()
21
  task_list = [{"id": k, **v} for k, v in TASKS.items()]
22
  return {
23
  "tasks": task_list,
24
- "action_schema": action_schema
25
  }
26
 
 
27
  class BaselineResponse(BaseModel):
28
  scores: Dict[int, float]
29
 
 
30
  @app.post("/baseline", response_model=BaselineResponse)
31
  async def run_baseline():
32
  import baseline
 
33
  try:
34
  scores = baseline.run_all_tasks()
35
  return BaselineResponse(scores=scores)
36
  except Exception as e:
37
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
12
  env=SQLEnv,
13
  action_cls=Action,
14
  observation_cls=Observation,
15
+ env_name="sql-query-optimizer",
16
  )
17
 
18
+
19
  @app.get("/tasks")
20
  async def get_tasks():
21
  action_schema = Action.model_json_schema()
22
  task_list = [{"id": k, **v} for k, v in TASKS.items()]
23
  return {
24
  "tasks": task_list,
25
+ "action_schema": action_schema,
26
  }
27
 
28
+
29
  class BaselineResponse(BaseModel):
30
  scores: Dict[int, float]
31
 
32
+
33
  @app.post("/baseline", response_model=BaselineResponse)
34
  async def run_baseline():
35
  import baseline
36
+
37
  try:
38
  scores = baseline.run_all_tasks()
39
  return BaselineResponse(scores=scores)
40
  except Exception as e:
41
  raise HTTPException(status_code=500, detail=str(e))
42
+
43
+
44
+ def main(host: str = "0.0.0.0", port: int = 7860):
45
+ import uvicorn
46
+ uvicorn.run(app, host=host, port=port)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()