Spaces:
Sleeping
Sleeping
File size: 4,937 Bytes
210535c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | """
Baseline inference script for the SQL Query Optimizer OpenEnv environment.
Usage:
python baseline.py # human-readable output
python baseline.py --json # JSON output (used by /baseline endpoint)
Requires:
OPENAI_API_KEY environment variable
The script runs gpt-4o-mini against all 3 tasks and reports grader scores.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from openai import OpenAI
# ── import env from local package ──────────────────────────────────────────
sys.path.insert(0, os.path.dirname(__file__))
from env.environment import SQLOptimizerEnv
from env.models import Action
# ──────────────────────────────────────────────────────────────────────────────
# Chat model queried by the baseline run.
MODEL = "gpt-4o-mini"

# Upper bound on the number of rewrite attempts made for a single task.
MAX_STEPS = 5

# Task ids evaluated by the baseline, in execution order.
TASKS = [1, 2, 3]

# System prompt sent with every request. It is assembled line by line; the
# lines are joined with newlines and there is no trailing newline.
SYSTEM_PROMPT = "\n".join(
    (
        "You are a database performance engineer.",
        "You will receive a broken or unoptimised SQL query along with table schema context.",
        "Your job is to rewrite the query so it is correct and performant.",
        "Respond ONLY with a JSON object with these exact keys:",
        "{",
        '"rewritten_query": "<your improved SQL>",',
        '"explanation": "<brief explanation of changes>",',
        '"is_done": true',
        "}",
        "Do not wrap in markdown. Output raw JSON only.",
    )
)
def _build_user_message(obs_dict: dict) -> str:
return (
f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} β difficulty: "
f"{obs_dict.get('difficulty', 'unknown')})\n\n"
f"Description:\n{obs_dict['task_description']}\n\n"
f"Schema:\n{obs_dict['schema_context']}\n\n"
f"Query to fix:\n{obs_dict['query']}"
+ (f"\n\nHint: {obs_dict['hint']}" if obs_dict.get("hint") else "")
)
def _parse_model_reply(content: str | None) -> dict:
    """Decode the model's reply text into a JSON object.

    Models sometimes wrap their JSON in a markdown code fence even when told
    not to; such a fence is removed before parsing so the reply is not
    discarded.

    Args:
        content: Raw message content from the chat completion; may be None.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If the text is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or does not decode to an object.
    """
    text = (content or "").strip()
    if text.startswith("```"):
        # Drop the opening fence line ("```" or "```json") and a closing fence.
        _, _, text = text.partition("\n")
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    parsed = json.loads(text)
    if not isinstance(parsed, dict):
        raise ValueError(
            f"expected a JSON object, got {type(parsed).__name__}"
        )
    return parsed


def run_baseline(verbose: bool = True) -> dict[str, float]:
    """Run MODEL against every task in TASKS and collect grader scores.

    For each task the environment is reset, then the model is queried for at
    most MAX_STEPS steps. Each step sends the system prompt plus the current
    observation, turns the reply into an ``Action`` and passes it to
    ``env.step``. If the request or the parsing of the reply fails, an empty
    ``Action`` with ``is_done=True`` is submitted instead.

    Args:
        verbose: When True, print per-step progress and a final summary to
            stdout. When False, nothing is written to stdout.

    Returns:
        Mapping of ``"task_<id>_<name>"`` to the last grader score reported
        for that task, rounded to 4 decimal places.

    Raises:
        SystemExit: With status 1 when OPENAI_API_KEY is not set.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY is not set.", file=sys.stderr)
        sys.exit(1)

    client = OpenAI(api_key=api_key)
    env = SQLOptimizerEnv()
    results: dict[str, float] = {}
    rule = "=" * 60

    for task_id in TASKS:
        obs = env.reset(task_id=task_id)
        obs_dict = obs.model_dump()
        final_score = 0.0
        if verbose:
            print(f"\n{rule}")
            print(f"Task {task_id}: {obs_dict['task_name']} [{obs_dict['task_id']}]")
            print(rule)

        for step_num in range(MAX_STEPS):
            # Each request carries only the latest observation; no earlier
            # turns are replayed to the model.
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _build_user_message(obs_dict)},
            ]
            try:
                response = client.chat.completions.create(
                    model=MODEL,
                    messages=messages,
                    temperature=0.0,
                    max_tokens=1024,
                )
                parsed = _parse_model_reply(response.choices[0].message.content)
                action = Action(
                    rewritten_query=parsed.get("rewritten_query", ""),
                    explanation=parsed.get("explanation", ""),
                    is_done=bool(parsed.get("is_done", False)),
                )
            except Exception as exc:
                # Deliberately broad: a failed request, an unparsable reply
                # and an invalid Action all fall back to an empty action.
                if verbose:
                    print(f" Step {step_num + 1}: LLM error \u2014 {exc}")
                action = Action(
                    rewritten_query="",
                    explanation="error",
                    is_done=True,
                )

            obs, reward, done, info = env.step(action)
            obs_dict = obs.model_dump()
            final_score = info["grader_score"]
            if verbose:
                print(
                    f" Step {step_num + 1}: grader_score={info['grader_score']:.3f} "
                    f"step_reward={reward.score:.4f} feedback={reward.feedback[:80]}"
                )
            if done:
                break

        # NOTE(review): reads the environment's private `_task` attribute to
        # build the result key.
        results[f"task_{task_id}_{env._task.name}"] = round(final_score, 4)
        if verbose:
            print(f" \u2192 Final grader score: {final_score:.4f}")

    if verbose:
        print(f"\n{rule}")
        print("BASELINE RESULTS")
        print(rule)
        for key, score in results.items():
            print(f" {key}: {score:.4f}")
        # With no tasks there is nothing to average; skip the line instead of
        # dividing by zero.
        if results:
            avg = sum(results.values()) / len(results)
            print(f" Average: {avg:.4f}")
    return results
if __name__ == "__main__":
    # CLI entry point: run the baseline and print either the human-readable
    # report (default) or a single JSON document (--json).
    parser = argparse.ArgumentParser(
        description="OpenEnv SQL Optimizer \u2014 Baseline Inference"
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Output results as JSON (used by /baseline endpoint)",
    )
    args = parser.parse_args()
    # Progress output is switched off in JSON mode so stdout carries only the
    # JSON document printed below.
    scores = run_baseline(verbose=not args.json)
    if args.json:
        print(json.dumps(scores))
|