File size: 4,937 Bytes
210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Baseline inference script for the SQL Query Optimizer OpenEnv environment.

Usage:
    python baseline.py              # human-readable output
    python baseline.py --json       # JSON output (used by /baseline endpoint)

Requires:
    OPENAI_API_KEY environment variable

The script runs gpt-4o-mini against all 3 tasks and reports grader scores.
"""
from __future__ import annotations

import argparse
import json
import os
import sys

from openai import OpenAI

# ── import env from local package ──────────────────────────────────────────
sys.path.insert(0, os.path.dirname(__file__))
from env.environment import SQLOptimizerEnv
from env.models import Action

# ──────────────────────────────────────────────────────────────────────────────
# Chat model name passed to the OpenAI chat-completions call for every request.
MODEL = "gpt-4o-mini"
# Upper bound on env.step() calls per task; the step loop also stops early as
# soon as the environment reports done.
MAX_STEPS = 5
# Task ids passed to env.reset(task_id=...), evaluated in this order.
TASKS = [1, 2, 3]

# System message sent with every request. The JSON keys it asks for are the
# ones this script reads back to build an Action (rewritten_query,
# explanation, is_done).
# NOTE(review): the template always asks for "is_done": true, so the model is
# told to finish on its first reply; if the environment ends the episode when
# is_done is set, MAX_STEPS > 1 rarely comes into play — confirm intended.
SYSTEM_PROMPT = """You are a database performance engineer.
You will receive a broken or unoptimised SQL query along with table schema context.
Your job is to rewrite the query so it is correct and performant.

Respond ONLY with a JSON object with these exact keys:
{
  "rewritten_query": "<your improved SQL>",
  "explanation": "<brief explanation of changes>",
  "is_done": true
}
Do not wrap in markdown. Output raw JSON only."""


def _build_user_message(obs_dict: dict) -> str:
    return (
        f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} β€” difficulty: "
        f"{obs_dict.get('difficulty', 'unknown')})\n\n"
        f"Description:\n{obs_dict['task_description']}\n\n"
        f"Schema:\n{obs_dict['schema_context']}\n\n"
        f"Query to fix:\n{obs_dict['query']}"
        + (f"\n\nHint: {obs_dict['hint']}" if obs_dict.get("hint") else "")
    )


def _strip_code_fences(text: str) -> str:
    """Return *text* stripped, minus a surrounding Markdown code fence if any.

    Models sometimes wrap JSON in ```json ... ``` despite being told not to;
    json.loads rejects that, so the fence (and an optional language tag on the
    opening line) is removed first.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    body = stripped[3:]
    if body.endswith("```"):
        body = body[:-3]
    # The opening fence line may carry a language tag such as "json"; drop it
    # unless that first line already holds the start of the JSON object.
    head, newline, rest = body.partition("\n")
    if newline and "{" not in head:
        body = rest
    return body.strip()


def _request_action(client: "OpenAI", user_message: str) -> "Action":
    """Ask the model for one rewrite and parse its JSON reply into an Action.

    Args:
        client: OpenAI client used for the chat-completions call.
        user_message: Rendered task prompt (see ``_build_user_message``).

    Returns:
        An Action built from the reply's ``rewritten_query``, ``explanation``
        and ``is_done`` keys. Missing or null text fields become ``""``.

    Raises:
        Whatever the OpenAI client raises on API/network errors;
        json.JSONDecodeError if the reply is not valid JSON; ValueError if the
        JSON is not an object; and any exception raised by the Action
        constructor.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=1024,
    )
    # message.content can be None (e.g. a refusal); treat it as an empty reply
    # so the failure surfaces as a JSON decode error instead of AttributeError.
    content = response.choices[0].message.content or ""
    parsed = json.loads(_strip_code_fences(content))
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return Action(
        rewritten_query=parsed.get("rewritten_query") or "",
        explanation=parsed.get("explanation") or "",
        is_done=bool(parsed.get("is_done", False)),
    )


def _print_summary(results: dict[str, float]) -> None:
    """Print the per-task scores and, when there are any, their mean."""
    print(f"\n{'='*60}")
    print("BASELINE RESULTS")
    print(f"{'='*60}")
    for name, score in results.items():
        print(f"  {name}: {score:.4f}")
    if results:  # avoid ZeroDivisionError when no task produced a result
        print(f"  Average: {sum(results.values()) / len(results):.4f}")


def run_baseline(verbose: bool = True) -> dict[str, float]:
    """Run MODEL against every task in TASKS and collect final grader scores.

    For each task the environment is reset, then the model is queried up to
    MAX_STEPS times. Each reply is parsed into an Action and passed to
    ``env.step`` until the environment reports done.

    Args:
        verbose: When True, print per-step progress and a summary table to
            stdout. When False nothing is printed to stdout, so ``--json``
            output consists only of the final JSON document.

    Returns:
        Mapping of ``"task_<id>_<name>"`` to the last grader score observed for
        that task, rounded to 4 decimal places.

    Exits:
        Calls ``sys.exit(1)`` after printing an error to stderr if
        OPENAI_API_KEY is not set.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY is not set.", file=sys.stderr)
        sys.exit(1)

    client = OpenAI(api_key=api_key)
    env = SQLOptimizerEnv()
    results: dict[str, float] = {}

    for task_id in TASKS:
        obs = env.reset(task_id=task_id)
        obs_dict = obs.model_dump()
        final_score = 0.0

        if verbose:
            print(f"\n{'='*60}")
            print(f"Task {task_id}: {obs_dict['task_name']} [{obs_dict['task_id']}]")
            print(f"{'='*60}")

        for step_num in range(MAX_STEPS):
            # Built outside the try so a malformed observation still raises
            # instead of being reported as an LLM error.
            user_message = _build_user_message(obs_dict)
            try:
                action = _request_action(client, user_message)
            except Exception as exc:
                # Deliberate boundary: any API, parsing or validation failure
                # becomes an empty, terminating action so the remaining tasks
                # still run.
                if verbose:
                    print(f"  Step {step_num + 1}: LLM error — {exc}")
                action = Action(
                    rewritten_query="",
                    explanation="error",
                    is_done=True,
                )

            obs, reward, done, info = env.step(action)
            obs_dict = obs.model_dump()
            final_score = info["grader_score"]

            if verbose:
                print(
                    f"  Step {step_num + 1}: grader_score={info['grader_score']:.3f}  "
                    f"step_reward={reward.score:.4f}  feedback={reward.feedback[:80]}"
                )

            if done:
                break

        # NOTE(review): reads the environment's private ``_task`` attribute to
        # name the result key; consider exposing the task name publicly.
        results[f"task_{task_id}_{env._task.name}"] = round(final_score, 4)

        if verbose:
            print(f"  → Final grader score: {final_score:.4f}")

    if verbose:
        _print_summary(results)

    return results


if __name__ == "__main__":
    # CLI entry point: human-readable progress by default; with --json only the
    # score mapping is printed to stdout, as a single JSON document.
    parser = argparse.ArgumentParser(description="OpenEnv SQL Optimizer — Baseline Inference")
    parser.add_argument(
        "--json", action="store_true", help="Output results as JSON (used by /baseline endpoint)"
    )
    args = parser.parse_args()

    scores = run_baseline(verbose=not args.json)
    if args.json:
        print(json.dumps(scores))