Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- baseline.py +37 -19
- env/environment.py +72 -47
- env/models.py +25 -17
- inference.py +89 -0
- models.py +25 -17
- server/app.py +15 -2
baseline.py
CHANGED
|
@@ -1,32 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
from openai import OpenAI
|
| 3 |
from env.environment import SQLEnv
|
| 4 |
from env.models import Action
|
| 5 |
|
|
|
|
| 6 |
def run_task(env: SQLEnv, task_id: int) -> float:
|
| 7 |
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
| 8 |
-
obs = env.reset(task_id)
|
| 9 |
-
|
| 10 |
messages = [
|
| 11 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
]
|
| 13 |
-
|
| 14 |
-
prompt = f"""
|
| 15 |
-
Task # {obs.task_id}
|
| 16 |
Original Query: {obs.query}
|
| 17 |
Database Schema Context: {obs.schema_context}
|
| 18 |
Hint: {obs.hint}
|
| 19 |
|
| 20 |
-
Please provide the optimized query. Output ONLY the raw SQL query, no markdown formatting, no explanation.
|
| 21 |
-
|
| 22 |
-
|
| 23 |
messages.append({"role": "user", "content": prompt.strip()})
|
| 24 |
-
|
| 25 |
try:
|
| 26 |
response = client.chat.completions.create(
|
| 27 |
model="gpt-3.5-turbo",
|
| 28 |
messages=messages,
|
| 29 |
-
temperature=0.0
|
| 30 |
)
|
| 31 |
rewritten_query = response.choices[0].message.content.strip()
|
| 32 |
if rewritten_query.startswith("```sql"):
|
|
@@ -37,30 +53,32 @@ Please provide the optimized query. Output ONLY the raw SQL query, no markdown f
|
|
| 37 |
except Exception as e:
|
| 38 |
print(f"Error calling OpenAI API: {e}")
|
| 39 |
rewritten_query = obs.query
|
| 40 |
-
|
| 41 |
action = Action(
|
| 42 |
rewritten_query=rewritten_query,
|
| 43 |
explanation="Baseline inference using LLM",
|
| 44 |
-
is_done=True
|
| 45 |
)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
return
|
|
|
|
| 49 |
|
| 50 |
def run_all_tasks():
|
| 51 |
if not os.environ.get("OPENAI_API_KEY"):
|
| 52 |
raise ValueError("OPENAI_API_KEY environment variable is required.")
|
| 53 |
-
|
| 54 |
env = SQLEnv()
|
| 55 |
scores = {}
|
| 56 |
for task_id in [1, 2, 3]:
|
| 57 |
print(f"Running baseline for Task {task_id}...")
|
| 58 |
score = run_task(env, task_id)
|
| 59 |
scores[task_id] = score
|
| 60 |
-
print(f"Task {task_id}
|
| 61 |
-
|
| 62 |
return scores
|
| 63 |
|
|
|
|
| 64 |
if __name__ == "__main__":
|
| 65 |
try:
|
| 66 |
scores = run_all_tasks()
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline inference script for the SQL Query Optimizer OpenEnv.
|
| 3 |
+
|
| 4 |
+
Uses the OpenAI API client to run a model against the environment
|
| 5 |
+
and produce reproducible baseline scores on all 3 tasks.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
export OPENAI_API_KEY=sk-...
|
| 9 |
+
python baseline.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
import os
|
| 13 |
from openai import OpenAI
|
| 14 |
from env.environment import SQLEnv
|
| 15 |
from env.models import Action
|
| 16 |
|
| 17 |
+
|
| 18 |
def run_task(env: SQLEnv, task_id: int) -> float:
|
| 19 |
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
| 20 |
+
obs = env.reset(task_id=task_id)
|
| 21 |
+
|
| 22 |
messages = [
|
| 23 |
+
{
|
| 24 |
+
"role": "system",
|
| 25 |
+
"content": (
|
| 26 |
+
"You are an expert SQL DBA. You rewrite SQL queries "
|
| 27 |
+
"to be correct, optimized, and performant."
|
| 28 |
+
),
|
| 29 |
+
}
|
| 30 |
]
|
| 31 |
+
|
| 32 |
+
prompt = f"""Task #{obs.task_id}
|
|
|
|
| 33 |
Original Query: {obs.query}
|
| 34 |
Database Schema Context: {obs.schema_context}
|
| 35 |
Hint: {obs.hint}
|
| 36 |
|
| 37 |
+
Please provide the optimized query. Output ONLY the raw SQL query, no markdown formatting, no explanation."""
|
| 38 |
+
|
|
|
|
| 39 |
messages.append({"role": "user", "content": prompt.strip()})
|
| 40 |
+
|
| 41 |
try:
|
| 42 |
response = client.chat.completions.create(
|
| 43 |
model="gpt-3.5-turbo",
|
| 44 |
messages=messages,
|
| 45 |
+
temperature=0.0,
|
| 46 |
)
|
| 47 |
rewritten_query = response.choices[0].message.content.strip()
|
| 48 |
if rewritten_query.startswith("```sql"):
|
|
|
|
| 53 |
except Exception as e:
|
| 54 |
print(f"Error calling OpenAI API: {e}")
|
| 55 |
rewritten_query = obs.query
|
| 56 |
+
|
| 57 |
action = Action(
|
| 58 |
rewritten_query=rewritten_query,
|
| 59 |
explanation="Baseline inference using LLM",
|
| 60 |
+
is_done=True,
|
| 61 |
)
|
| 62 |
+
|
| 63 |
+
result_obs = env.step(action)
|
| 64 |
+
return result_obs.reward
|
| 65 |
+
|
| 66 |
|
| 67 |
def run_all_tasks():
|
| 68 |
if not os.environ.get("OPENAI_API_KEY"):
|
| 69 |
raise ValueError("OPENAI_API_KEY environment variable is required.")
|
| 70 |
+
|
| 71 |
env = SQLEnv()
|
| 72 |
scores = {}
|
| 73 |
for task_id in [1, 2, 3]:
|
| 74 |
print(f"Running baseline for Task {task_id}...")
|
| 75 |
score = run_task(env, task_id)
|
| 76 |
scores[task_id] = score
|
| 77 |
+
print(f"Task {task_id} Score: {score}")
|
| 78 |
+
|
| 79 |
return scores
|
| 80 |
|
| 81 |
+
|
| 82 |
if __name__ == "__main__":
|
| 83 |
try:
|
| 84 |
scores = run_all_tasks()
|
env/environment.py
CHANGED
|
@@ -1,9 +1,19 @@
|
|
| 1 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from .models import Observation, Action, Reward
|
| 3 |
from .tasks import TASKS, grade_action, get_task
|
| 4 |
from .reward import compute_reward
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
def __init__(self):
|
| 8 |
self.current_task_id = None
|
| 9 |
self.task = None
|
|
@@ -13,12 +23,19 @@ class SQLEnv:
|
|
| 13 |
self.cumulative_score = 0.0
|
| 14 |
self.previous_grader_score = 0.0
|
| 15 |
self.final_grader_score = 0.0
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
task = get_task(task_id)
|
| 19 |
if not task:
|
| 20 |
raise ValueError(f"Task {task_id} not found.")
|
| 21 |
-
|
| 22 |
self.current_task_id = task_id
|
| 23 |
self.task = task
|
| 24 |
self.step_number = 1
|
|
@@ -27,80 +44,88 @@ class SQLEnv:
|
|
| 27 |
self.cumulative_score = 0.0
|
| 28 |
self.previous_grader_score = 0.0
|
| 29 |
self.final_grader_score = 0.0
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
obs = Observation(
|
| 32 |
task_id=self.current_task_id,
|
| 33 |
query=self.task["initial_query"],
|
| 34 |
schema_context=self.task["schema_context"],
|
| 35 |
hint=self.task["hint"],
|
| 36 |
step_number=self.step_number,
|
| 37 |
-
max_steps=self.max_steps
|
|
|
|
|
|
|
| 38 |
)
|
| 39 |
self.history.append({"step": 0, "type": "reset", "observation": obs.model_dump()})
|
| 40 |
return obs
|
| 41 |
-
|
| 42 |
-
def step(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
if not self.task:
|
| 44 |
raise RuntimeError("Environment not initialized. Call reset() first.")
|
| 45 |
-
|
| 46 |
-
grader_score, breakdown, feedback = grade_action(
|
|
|
|
|
|
|
| 47 |
action_valid = len(action.rewritten_query.strip()) > 0
|
| 48 |
-
|
| 49 |
done = action.is_done or self.step_number >= self.max_steps
|
| 50 |
-
|
| 51 |
step_reward = compute_reward(
|
| 52 |
grader_score=grader_score,
|
| 53 |
previous_score=self.previous_grader_score,
|
| 54 |
step_number=self.step_number,
|
| 55 |
max_steps=self.max_steps,
|
| 56 |
is_done=done,
|
| 57 |
-
action_valid=action_valid
|
| 58 |
)
|
| 59 |
-
|
| 60 |
self.cumulative_score += step_reward
|
| 61 |
self.previous_grader_score = grader_score
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
obs = Observation(
|
| 70 |
task_id=self.current_task_id,
|
| 71 |
query=action.rewritten_query,
|
| 72 |
schema_context=self.task["schema_context"],
|
| 73 |
-
hint=self.task["hint"],
|
| 74 |
step_number=self.step_number + 1,
|
| 75 |
-
max_steps=self.max_steps
|
|
|
|
|
|
|
|
|
|
| 76 |
)
|
| 77 |
-
|
| 78 |
-
info = {
|
| 79 |
-
"cumulative_score": self.cumulative_score,
|
| 80 |
-
"grader_score": grader_score
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
if done:
|
| 84 |
-
self.final_grader_score = grader_score
|
| 85 |
-
|
| 86 |
self.history.append({
|
| 87 |
"step": self.step_number,
|
| 88 |
"type": "step",
|
| 89 |
"action": action.model_dump(),
|
| 90 |
-
"reward":
|
| 91 |
"done": done,
|
| 92 |
-
"info": info
|
| 93 |
})
|
| 94 |
-
|
| 95 |
self.step_number += 1
|
| 96 |
-
return obs
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
"step_number": self.step_number,
|
| 102 |
-
"max_steps": self.max_steps,
|
| 103 |
-
"cumulative_score": self.cumulative_score,
|
| 104 |
-
"final_grader_score": self.final_grader_score,
|
| 105 |
-
"history": self.history
|
| 106 |
-
}
|
|
|
|
| 1 |
+
from typing import Optional, Dict, Any
|
| 2 |
+
from uuid import uuid4
|
| 3 |
+
|
| 4 |
+
from openenv.core.env_server.interfaces import Environment
|
| 5 |
+
from openenv.core.env_server.types import State
|
| 6 |
+
|
| 7 |
from .models import Observation, Action, Reward
|
| 8 |
from .tasks import TASKS, grade_action, get_task
|
| 9 |
from .reward import compute_reward
|
| 10 |
|
| 11 |
+
|
| 12 |
+
class SQLEnv(Environment):
|
| 13 |
+
"""SQL Query Optimizer Environment following the OpenEnv interface."""
|
| 14 |
+
|
| 15 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 16 |
+
|
| 17 |
def __init__(self):
|
| 18 |
self.current_task_id = None
|
| 19 |
self.task = None
|
|
|
|
| 23 |
self.cumulative_score = 0.0
|
| 24 |
self.previous_grader_score = 0.0
|
| 25 |
self.final_grader_score = 0.0
|
| 26 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 27 |
+
|
| 28 |
+
def reset(
|
| 29 |
+
self,
|
| 30 |
+
seed: Optional[int] = None,
|
| 31 |
+
episode_id: Optional[str] = None,
|
| 32 |
+
task_id: int = 1,
|
| 33 |
+
**kwargs: Any,
|
| 34 |
+
) -> Observation:
|
| 35 |
task = get_task(task_id)
|
| 36 |
if not task:
|
| 37 |
raise ValueError(f"Task {task_id} not found.")
|
| 38 |
+
|
| 39 |
self.current_task_id = task_id
|
| 40 |
self.task = task
|
| 41 |
self.step_number = 1
|
|
|
|
| 44 |
self.cumulative_score = 0.0
|
| 45 |
self.previous_grader_score = 0.0
|
| 46 |
self.final_grader_score = 0.0
|
| 47 |
+
self._state = State(
|
| 48 |
+
episode_id=episode_id or str(uuid4()),
|
| 49 |
+
step_count=0,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
obs = Observation(
|
| 53 |
task_id=self.current_task_id,
|
| 54 |
query=self.task["initial_query"],
|
| 55 |
schema_context=self.task["schema_context"],
|
| 56 |
hint=self.task["hint"],
|
| 57 |
step_number=self.step_number,
|
| 58 |
+
max_steps=self.max_steps,
|
| 59 |
+
reward=0.0,
|
| 60 |
+
done=False,
|
| 61 |
)
|
| 62 |
self.history.append({"step": 0, "type": "reset", "observation": obs.model_dump()})
|
| 63 |
return obs
|
| 64 |
+
|
| 65 |
+
def step(
|
| 66 |
+
self,
|
| 67 |
+
action: Action,
|
| 68 |
+
timeout_s: Optional[float] = None,
|
| 69 |
+
**kwargs: Any,
|
| 70 |
+
) -> Observation:
|
| 71 |
if not self.task:
|
| 72 |
raise RuntimeError("Environment not initialized. Call reset() first.")
|
| 73 |
+
|
| 74 |
+
grader_score, breakdown, feedback = grade_action(
|
| 75 |
+
self.current_task_id, action.rewritten_query
|
| 76 |
+
)
|
| 77 |
action_valid = len(action.rewritten_query.strip()) > 0
|
| 78 |
+
|
| 79 |
done = action.is_done or self.step_number >= self.max_steps
|
| 80 |
+
|
| 81 |
step_reward = compute_reward(
|
| 82 |
grader_score=grader_score,
|
| 83 |
previous_score=self.previous_grader_score,
|
| 84 |
step_number=self.step_number,
|
| 85 |
max_steps=self.max_steps,
|
| 86 |
is_done=done,
|
| 87 |
+
action_valid=action_valid,
|
| 88 |
)
|
| 89 |
+
|
| 90 |
self.cumulative_score += step_reward
|
| 91 |
self.previous_grader_score = grader_score
|
| 92 |
+
|
| 93 |
+
info = {
|
| 94 |
+
"cumulative_score": self.cumulative_score,
|
| 95 |
+
"grader_score": grader_score,
|
| 96 |
+
"breakdown": breakdown,
|
| 97 |
+
"feedback": feedback,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
if done:
|
| 101 |
+
self.final_grader_score = grader_score
|
| 102 |
+
|
| 103 |
+
self._state.step_count += 1
|
| 104 |
+
|
| 105 |
obs = Observation(
|
| 106 |
task_id=self.current_task_id,
|
| 107 |
query=action.rewritten_query,
|
| 108 |
schema_context=self.task["schema_context"],
|
| 109 |
+
hint=self.task["hint"],
|
| 110 |
step_number=self.step_number + 1,
|
| 111 |
+
max_steps=self.max_steps,
|
| 112 |
+
reward=step_reward,
|
| 113 |
+
done=done,
|
| 114 |
+
metadata=info,
|
| 115 |
)
|
| 116 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
self.history.append({
|
| 118 |
"step": self.step_number,
|
| 119 |
"type": "step",
|
| 120 |
"action": action.model_dump(),
|
| 121 |
+
"reward": step_reward,
|
| 122 |
"done": done,
|
| 123 |
+
"info": info,
|
| 124 |
})
|
| 125 |
+
|
| 126 |
self.step_number += 1
|
| 127 |
+
return obs
|
| 128 |
+
|
| 129 |
+
@property
|
| 130 |
+
def state(self) -> State:
|
| 131 |
+
return self._state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
env/models.py
CHANGED
|
@@ -1,20 +1,28 @@
|
|
| 1 |
-
from typing import Optional, Dict
|
| 2 |
-
from pydantic import
|
|
|
|
| 3 |
|
| 4 |
-
class Observation(BaseModel):
|
| 5 |
-
task_id: int = Field(description="The ID of the task to perform.")
|
| 6 |
-
query: str = Field(description="The SQL query to review and optimize.")
|
| 7 |
-
schema_context: str = Field(description="The database schema context for the query, such as CREATE TABLE statements.")
|
| 8 |
-
hint: Optional[str] = Field(default=None, description="An optional natural-language hint or description of the problem.")
|
| 9 |
-
step_number: int = Field(description="The current step number in the episode (1-indexed).")
|
| 10 |
-
max_steps: int = Field(description="The maximum allowed steps for this task.")
|
| 11 |
|
| 12 |
-
class
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Dict, Any
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
from openenv.core.env_server.types import Action as BaseAction, Observation as BaseObservation
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
class Observation(BaseObservation):
|
| 7 |
+
task_id: int = Field(default=0, description="The ID of the task to perform.")
|
| 8 |
+
query: str = Field(default="", description="The SQL query to review and optimize.")
|
| 9 |
+
schema_context: str = Field(default="", description="The database schema context.")
|
| 10 |
+
hint: Optional[str] = Field(default=None, description="An optional natural-language hint.")
|
| 11 |
+
step_number: int = Field(default=0, description="The current step number in the episode.")
|
| 12 |
+
max_steps: int = Field(default=0, description="The maximum allowed steps for this task.")
|
| 13 |
|
| 14 |
+
|
| 15 |
+
class Action(BaseAction):
|
| 16 |
+
rewritten_query: str = Field(default="", description="The rewritten, optimized SQL query.")
|
| 17 |
+
explanation: str = Field(default="", description="A brief explanation of the changes.")
|
| 18 |
+
is_done: bool = Field(default=False, description="Set to true to submit for final scoring.")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Reward:
|
| 22 |
+
def __init__(self, score: float = 0.0, breakdown: Dict[str, float] = None, feedback: str = ""):
|
| 23 |
+
self.score = score
|
| 24 |
+
self.breakdown = breakdown or {}
|
| 25 |
+
self.feedback = feedback
|
| 26 |
+
|
| 27 |
+
def model_dump(self):
|
| 28 |
+
return {"score": self.score, "breakdown": self.breakdown, "feedback": self.feedback}
|
inference.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline inference script for the SQL Query Optimizer OpenEnv.
|
| 3 |
+
|
| 4 |
+
Uses the OpenAI API client to run a model against the environment
|
| 5 |
+
and produce reproducible baseline scores on all 3 tasks.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
export OPENAI_API_KEY=sk-...
|
| 9 |
+
python inference.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
from openai import OpenAI
|
| 14 |
+
from env.environment import SQLEnv
|
| 15 |
+
from env.models import Action
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def run_task(env: SQLEnv, task_id: int) -> float:
|
| 19 |
+
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
| 20 |
+
obs = env.reset(task_id=task_id)
|
| 21 |
+
|
| 22 |
+
messages = [
|
| 23 |
+
{
|
| 24 |
+
"role": "system",
|
| 25 |
+
"content": (
|
| 26 |
+
"You are an expert SQL DBA. You rewrite SQL queries "
|
| 27 |
+
"to be correct, optimized, and performant."
|
| 28 |
+
),
|
| 29 |
+
}
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
prompt = f"""Task #{obs.task_id}
|
| 33 |
+
Original Query: {obs.query}
|
| 34 |
+
Database Schema Context: {obs.schema_context}
|
| 35 |
+
Hint: {obs.hint}
|
| 36 |
+
|
| 37 |
+
Please provide the optimized query. Output ONLY the raw SQL query, no markdown formatting, no explanation."""
|
| 38 |
+
|
| 39 |
+
messages.append({"role": "user", "content": prompt.strip()})
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
response = client.chat.completions.create(
|
| 43 |
+
model="gpt-3.5-turbo",
|
| 44 |
+
messages=messages,
|
| 45 |
+
temperature=0.0,
|
| 46 |
+
)
|
| 47 |
+
rewritten_query = response.choices[0].message.content.strip()
|
| 48 |
+
if rewritten_query.startswith("```sql"):
|
| 49 |
+
rewritten_query = rewritten_query[6:]
|
| 50 |
+
if rewritten_query.endswith("```"):
|
| 51 |
+
rewritten_query = rewritten_query[:-3]
|
| 52 |
+
rewritten_query = rewritten_query.strip()
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(f"Error calling OpenAI API: {e}")
|
| 55 |
+
rewritten_query = obs.query
|
| 56 |
+
|
| 57 |
+
action = Action(
|
| 58 |
+
rewritten_query=rewritten_query,
|
| 59 |
+
explanation="Baseline inference using LLM",
|
| 60 |
+
is_done=True,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
result_obs = env.step(action)
|
| 64 |
+
return result_obs.reward
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def run_all_tasks():
|
| 68 |
+
if not os.environ.get("OPENAI_API_KEY"):
|
| 69 |
+
raise ValueError("OPENAI_API_KEY environment variable is required.")
|
| 70 |
+
|
| 71 |
+
env = SQLEnv()
|
| 72 |
+
scores = {}
|
| 73 |
+
for task_id in [1, 2, 3]:
|
| 74 |
+
print(f"Running baseline for Task {task_id}...")
|
| 75 |
+
score = run_task(env, task_id)
|
| 76 |
+
scores[task_id] = score
|
| 77 |
+
print(f"Task {task_id} Score: {score}")
|
| 78 |
+
|
| 79 |
+
return scores
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
try:
|
| 84 |
+
scores = run_all_tasks()
|
| 85 |
+
print("\nBaseline Evaluation Results:")
|
| 86 |
+
for t, s in scores.items():
|
| 87 |
+
print(f"Task {t}: {s}/1.0")
|
| 88 |
+
except Exception as e:
|
| 89 |
+
print(f"Baseline Evaluation Failed: {e}")
|
models.py
CHANGED
|
@@ -1,20 +1,28 @@
|
|
| 1 |
-
from typing import Optional, Dict
|
| 2 |
-
from pydantic import
|
|
|
|
| 3 |
|
| 4 |
-
class Observation(BaseModel):
|
| 5 |
-
task_id: int = Field(description="The ID of the task to perform.")
|
| 6 |
-
query: str = Field(description="The SQL query to review and optimize.")
|
| 7 |
-
schema_context: str = Field(description="The database schema context for the query, such as CREATE TABLE statements.")
|
| 8 |
-
hint: Optional[str] = Field(default=None, description="An optional natural-language hint or description of the problem.")
|
| 9 |
-
step_number: int = Field(description="The current step number in the episode (1-indexed).")
|
| 10 |
-
max_steps: int = Field(description="The maximum allowed steps for this task.")
|
| 11 |
|
| 12 |
-
class
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Dict, Any
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
from openenv.core.env_server.types import Action as BaseAction, Observation as BaseObservation
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
class Observation(BaseObservation):
|
| 7 |
+
task_id: int = Field(default=0, description="The ID of the task to perform.")
|
| 8 |
+
query: str = Field(default="", description="The SQL query to review and optimize.")
|
| 9 |
+
schema_context: str = Field(default="", description="The database schema context.")
|
| 10 |
+
hint: Optional[str] = Field(default=None, description="An optional natural-language hint.")
|
| 11 |
+
step_number: int = Field(default=0, description="The current step number in the episode.")
|
| 12 |
+
max_steps: int = Field(default=0, description="The maximum allowed steps for this task.")
|
| 13 |
|
| 14 |
+
|
| 15 |
+
class Action(BaseAction):
|
| 16 |
+
rewritten_query: str = Field(default="", description="The rewritten, optimized SQL query.")
|
| 17 |
+
explanation: str = Field(default="", description="A brief explanation of the changes.")
|
| 18 |
+
is_done: bool = Field(default=False, description="Set to true to submit for final scoring.")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Reward:
|
| 22 |
+
def __init__(self, score: float = 0.0, breakdown: Dict[str, float] = None, feedback: str = ""):
|
| 23 |
+
self.score = score
|
| 24 |
+
self.breakdown = breakdown or {}
|
| 25 |
+
self.feedback = feedback
|
| 26 |
+
|
| 27 |
+
def model_dump(self):
|
| 28 |
+
return {"score": self.score, "breakdown": self.breakdown, "feedback": self.feedback}
|
server/app.py
CHANGED
|
@@ -12,26 +12,39 @@ app = create_app(
|
|
| 12 |
env=SQLEnv,
|
| 13 |
action_cls=Action,
|
| 14 |
observation_cls=Observation,
|
| 15 |
-
env_name="sql-query-optimizer"
|
| 16 |
)
|
| 17 |
|
|
|
|
| 18 |
@app.get("/tasks")
|
| 19 |
async def get_tasks():
|
| 20 |
action_schema = Action.model_json_schema()
|
| 21 |
task_list = [{"id": k, **v} for k, v in TASKS.items()]
|
| 22 |
return {
|
| 23 |
"tasks": task_list,
|
| 24 |
-
"action_schema": action_schema
|
| 25 |
}
|
| 26 |
|
|
|
|
| 27 |
class BaselineResponse(BaseModel):
|
| 28 |
scores: Dict[int, float]
|
| 29 |
|
|
|
|
| 30 |
@app.post("/baseline", response_model=BaselineResponse)
|
| 31 |
async def run_baseline():
|
| 32 |
import baseline
|
|
|
|
| 33 |
try:
|
| 34 |
scores = baseline.run_all_tasks()
|
| 35 |
return BaselineResponse(scores=scores)
|
| 36 |
except Exception as e:
|
| 37 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
env=SQLEnv,
|
| 13 |
action_cls=Action,
|
| 14 |
observation_cls=Observation,
|
| 15 |
+
env_name="sql-query-optimizer",
|
| 16 |
)
|
| 17 |
|
| 18 |
+
|
| 19 |
@app.get("/tasks")
|
| 20 |
async def get_tasks():
|
| 21 |
action_schema = Action.model_json_schema()
|
| 22 |
task_list = [{"id": k, **v} for k, v in TASKS.items()]
|
| 23 |
return {
|
| 24 |
"tasks": task_list,
|
| 25 |
+
"action_schema": action_schema,
|
| 26 |
}
|
| 27 |
|
| 28 |
+
|
| 29 |
class BaselineResponse(BaseModel):
|
| 30 |
scores: Dict[int, float]
|
| 31 |
|
| 32 |
+
|
| 33 |
@app.post("/baseline", response_model=BaselineResponse)
|
| 34 |
async def run_baseline():
|
| 35 |
import baseline
|
| 36 |
+
|
| 37 |
try:
|
| 38 |
scores = baseline.run_all_tasks()
|
| 39 |
return BaselineResponse(scores=scores)
|
| 40 |
except Exception as e:
|
| 41 |
raise HTTPException(status_code=500, detail=str(e))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main(host: str = "0.0.0.0", port: int = 7860):
|
| 45 |
+
import uvicorn
|
| 46 |
+
uvicorn.run(app, host=host, port=port)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
main()
|