File size: 9,503 Bytes
7d21a80
 
 
c15d346
7d21a80
c15d346
 
 
7d21a80
 
c15d346
 
 
7d21a80
 
 
c15d346
7d21a80
c15d346
 
7d21a80
 
 
 
 
 
 
 
c15d346
7d21a80
 
 
 
c15d346
 
 
7d21a80
 
 
c15d346
 
 
 
7d21a80
 
c15d346
7d21a80
c15d346
 
 
 
 
 
 
 
 
 
 
 
 
 
7d21a80
 
 
c15d346
 
 
7d21a80
c15d346
7d21a80
 
c15d346
 
 
7d21a80
 
 
 
c15d346
7d21a80
 
 
 
 
c15d346
 
 
7d21a80
c15d346
 
7d21a80
 
 
 
 
c15d346
7d21a80
c15d346
 
7d21a80
 
 
 
c15d346
7d21a80
c15d346
 
7d21a80
 
c15d346
 
 
7d21a80
 
 
 
 
 
c15d346
 
 
7d21a80
c15d346
7d21a80
 
 
c15d346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d21a80
c15d346
 
 
 
 
7d21a80
c15d346
 
 
 
 
 
 
 
 
 
 
7d21a80
 
c15d346
7d21a80
c15d346
7d21a80
 
 
c15d346
7d21a80
 
 
 
 
c15d346
7d21a80
 
c15d346
 
7d21a80
c15d346
7d21a80
 
c15d346
7d21a80
 
 
c15d346
7d21a80
 
 
 
c15d346
7d21a80
 
 
 
 
 
 
 
 
 
 
 
c15d346
7d21a80
 
 
 
 
 
 
 
 
 
 
 
 
c15d346
 
 
 
 
 
 
 
 
 
7d21a80
 
 
 
c15d346
 
7d21a80
 
 
 
 
 
 
 
c15d346
 
7d21a80
 
c15d346
7d21a80
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
"""
inference.py — SQL Query Optimization Environment
===================================================
Multi-step inference loop with execution-feedback awareness.

When the environment returns execution results from a previous step,
the agent uses them to REFINE its optimized query — creating a genuine
iterative optimization loop grounded in real performance data.

stdout format (strictly followed):
  [START] task=<task_id> env=sql-optim-env model=<MODEL_NAME>
  [STEP]  step=<n> action=<summary> reward=<0.00> done=<bool> error=<msg|null>
  [END]   success=<bool> steps=<n> score=<score> rewards=<r1,...,rn>
"""

import json
import os
import sys
from typing import Dict, List, Optional

from openai import OpenAI

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR)

from env import SQLOptimEnv
from models import Action

# ── Config ────────────────────────────────────────────────────────────────
# Endpoint, model name and credentials are read from the environment; the
# first two fall back to the defaults shown here.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# HF_TOKEN takes precedence; API_KEY is consulted only when HF_TOKEN is
# unset or empty.
HF_TOKEN = os.getenv("HF_TOKEN", "") or os.getenv("API_KEY", "")

# Environment name printed in the [START] log line.
BENCHMARK = "sql-optim-env"
# Sampling parameters passed to the chat-completions call.
TEMPERATURE = 0.0
MAX_TOKENS = 2000

# Task identifiers, run in this order by main().
TASK_IDS = [
    "task_1_basic_antipatterns",
    "task_2_correlated_subqueries",
    "task_3_wildcard_scan",
    "task_4_implicit_join",
    "task_5_window_functions",
]

# ── System prompt ─────────────────────────────────────────────────────────
# Sent as the "system" message on every model call. It fixes the JSON reply
# schema that parse_action() expects. The trailing backslashes inside the
# literal are line continuations, so those sentences reach the model as a
# single line. Em dashes are written as real U+2014 characters (the text
# previously contained mis-decoded byte sequences in their place).
SYSTEM_PROMPT = """\
You are an elite database engineer and SQL performance specialist with expert-level \
knowledge of PostgreSQL/DuckDB internals, query planning, columnar storage, \
and index design.

You will receive a SQL query and its schema. Your job:
1. Identify ALL performance anti-patterns.
2. Produce a complete, correct, optimized rewrite.
3. Your optimized_query will be ACTUALLY EXECUTED against a DuckDB database \
   with realistic data (orders=500k rows, events=1M rows). \
   If it returns wrong results or errors, your score drops.
4. If you receive execution feedback from a previous step, USE IT to refine \
   your rewrite — fix incorrect results first, then improve speed.

Respond ONLY with valid JSON (no markdown, no fences):
{
  "suggestions": [
    {
      "issue_type": "e.g. select_star / correlated_subquery / wildcard_like",
      "line": <integer>,
      "description": "precise explanation of the performance problem",
      "severity": "critical | high | medium | low",
      "fix": "specific rewrite or corrective SQL"
    }
  ],
  "optimized_query": "<complete, executable SQL that produces IDENTICAL results to original>",
  "summary": "2-4 sentence performance profile of the original query",
  "estimated_improvement": "e.g. '15x faster — eliminates N+1 subquery pattern'",
  "approved": false
}
"""

# ── Logging (strict OpenEnv format) ──────────────────────────────────────

def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task run.

    Args:
        task: Task identifier.
        env: Environment/benchmark name.
        model: Name of the model being evaluated.
    """
    fields = ("[START]", f"task={task}", f"env={env}", f"model={model}")
    # flush so the record is visible immediately even when stdout is piped.
    print(" ".join(fields), flush=True)


def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` record on a single line.

    Args:
        step: 1-based step number.
        action: Short summary of the action taken.
        reward: Step reward, printed with two decimals.
        done: Whether the episode finished on this step (printed lowercase).
        error: Error message for this step, or ``None``/empty for ``null``.
    """
    # Exception messages can span several lines; join them with spaces so the
    # record stays on one line, as the strict stdout format requires.
    error_text = " ".join(str(error).splitlines()) if error else ""
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_text or 'null'}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes a task run.

    Args:
        success: Whether the run counts as successful (printed lowercase).
        steps: Number of steps taken.
        score: Final score, printed with two decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals
            each (an empty list yields an empty ``rewards=`` field).
    """
    reward_fields = [format(value, ".2f") for value in rewards]
    record = (
        "[END] success=" + str(success).lower()
        + f" steps={steps} score={score:.2f} rewards="
        + ",".join(reward_fields)
    )
    print(record, flush=True)


# ── Model interaction ─────────────────────────────────────────────────────

def parse_action(text: str) -> Dict:
    """Parse the model's reply into an action dictionary.

    A surrounding Markdown code fence (optionally tagged ``json``) is stripped
    first. If the remaining text is not valid JSON as a whole, the span from
    the first ``{`` to the last ``}`` is tried, which recovers an object that
    the model wrapped in extra prose.

    Args:
        text: Raw text returned by the model.

    Returns:
        The decoded JSON object. When no JSON *object* can be decoded
        (malformed text, or valid JSON that is a list/scalar/null), a
        fallback dict with empty suggestions, an empty ``optimized_query``
        and ``approved=False`` is returned, so callers can always use
        ``.get`` on the result.
    """
    clean = text.strip()
    if clean.startswith("```"):
        lines = clean.split("\n")
        # Drop the opening fence line, and the closing one when present.
        clean = "\n".join(
            lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
        )
        if clean.startswith("json"):
            clean = clean[4:].strip()

    candidates = [clean]
    first_brace, last_brace = clean.find("{"), clean.rfind("}")
    if 0 <= first_brace < last_brace:
        candidates.append(clean[first_brace:last_brace + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Only a JSON object is a usable action; lists/scalars are rejected.
        if isinstance(parsed, dict):
            return parsed

    return {
        "suggestions":          [],
        "optimized_query":      "",
        "summary":              "Parse error — model returned malformed JSON.",
        "estimated_improvement": "unknown",
        "approved":             False,
    }


def _format_metric(value, spec: str, missing: str = "?") -> str:
    """Format *value* with format-spec *spec*, or return *missing* if it cannot be formatted."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        # None, the "?" placeholder, or any other non-numeric value.
        return missing


def build_user_prompt(obs) -> str:
    """Build the user message for one model call from an observation.

    The prompt always contains the task name, difficulty, step counter,
    instructions, schema and the SQL query. Two optional sections follow:
    the issue types already flagged (when ``obs.issues_found_so_far`` is
    truthy) and execution feedback (when ``obs.last_execution`` is truthy).

    Args:
        obs: Observation object exposing ``task_name``, ``difficulty``,
            ``step_count``, ``max_steps``, ``task_description``,
            ``schema_info``, ``sql_query``, ``issues_found_so_far`` and
            ``last_execution`` (a mapping read with ``.get``).

    Returns:
        The complete prompt text.
    """
    exec_feedback = ""
    if obs.last_execution:
        ex = obs.last_execution
        # Timings may be missing or None (e.g. when a query failed); format
        # them defensively instead of applying a float spec to a "?" default,
        # which raised ValueError.
        original_ms = _format_metric(ex.get('original_ms'), ".1f")
        optimized_ms = _format_metric(ex.get('optimized_ms'), ".1f")
        speedup = _format_metric(ex.get('speedup', 1.0), ".2f")
        match_text = (
            '✅ YES' if ex.get('results_match')
            else '❌ NO — fix your WHERE/JOIN logic'
        )
        exec_feedback = (
            f"\n\n⚡ EXECUTION FEEDBACK FROM YOUR LAST OPTIMIZED QUERY:\n"
            f"  Original query  : {original_ms} ms "
            f"  ({ex.get('original_rows', 0)} rows)\n"
            f"  Your last query : {optimized_ms} ms "
            f"  ({ex.get('optimized_rows', 0)} rows)\n"
            f"  Speedup achieved: {speedup}x\n"
            f"  Results match   : {match_text}\n"
            f"  Verdict         : {ex.get('verdict', '')}\n"
            f"Refine your optimized_query to fix any correctness issues first, "
            f"then improve speed further."
        )

    issues_ctx = ""
    if obs.issues_found_so_far:
        issues_ctx = (
            f"\nIssue types you've already flagged: {obs.issues_found_so_far}"
        )

    return (
        f"Task        : {obs.task_name}\n"
        f"Difficulty  : {obs.difficulty}\n"
        f"Step        : {obs.step_count + 1} / {obs.max_steps}\n\n"
        f"Instructions:\n{obs.task_description}\n\n"
        f"Database Schema:\n{obs.schema_info}\n\n"
        f"SQL Query to Optimize:\n```sql\n{obs.sql_query}\n```"
        f"{issues_ctx}"
        f"{exec_feedback}\n\n"
        f"Provide your complete analysis and optimized_query now."
    )


def call_model(client: OpenAI, obs) -> tuple:
    """Ask the model for an action on the current observation.

    Args:
        client: OpenAI-compatible client used for the chat-completions call.
        obs: Current observation, rendered into the user prompt.

    Returns:
        A ``(parsed_action, error)`` pair. On success ``error`` is ``None``
        and ``parsed_action`` is the result of ``parse_action`` on the reply
        text. If anything in the request or parsing raises, ``parsed_action``
        is an empty placeholder action whose summary carries the exception
        text and ``error`` is ``str(exc)``.
    """
    try:
        create_completion = client.chat.completions.create
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_prompt(obs)},
        ]
        completion = create_completion(
            model=MODEL_NAME,
            messages=conversation,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        # A reply with no content is treated as an empty string.
        reply_text = completion.choices[0].message.content or ""
        return parse_action(reply_text), None
    except Exception as exc:
        # Best-effort boundary: any failure becomes a placeholder action so
        # the episode can still be stepped and the error logged.
        placeholder = {
            "suggestions": [],
            "optimized_query": "",
            "approved": False,
            "summary": f"Model error: {exc}",
            "estimated_improvement": "unknown",
        }
        return placeholder, str(exc)


# ── Main loop ─────────────────────────────────────────────────────────────

def main():
    """Run every task in ``TASK_IDS`` and return a per-task result summary.

    Exits with status 1 (after printing an ``[ERROR]`` line) when no API
    token is configured. For each task the environment is reset, the model
    is queried up to ``obs.max_steps`` times, and each action is stepped
    through the environment. A task's score is its best step reward; it is
    considered successful when that score is at least 0.5.

    Returns:
        Dict mapping task id to ``{"task_name", "final_score", "steps_taken"}``.
    """
    if not HF_TOKEN:
        print("[ERROR] HF_TOKEN not set.", flush=True)
        sys.exit(1)

    llm_client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
    environment = SQLOptimEnv()
    summary_by_task: Dict[str, Dict] = {}

    for task_id in TASK_IDS:
        observation = environment.reset(task_id=task_id)
        log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

        reward_history: List[float] = []
        completed_steps = 0
        best_score = 0.0
        succeeded = False

        try:
            # The step budget is taken from the initial observation.
            for step_no in range(1, observation.max_steps + 1):
                reply, model_error = call_model(llm_client, observation)

                action = Action(
                    suggestions=reply.get("suggestions", []),
                    optimized_query=reply.get("optimized_query", ""),
                    summary=reply.get("summary", ""),
                    estimated_improvement=reply.get("estimated_improvement", ""),
                    approved=reply.get("approved", False),
                )

                outcome = environment.step(action)
                step_reward = outcome.reward.score
                finished = outcome.done

                # Execution details (if any) feed the one-line action summary.
                execution = outcome.info.get("execution") or {}
                speedup = execution.get("speedup", 1.0)
                matches = execution.get("results_match", False)
                summary_parts = [
                    f"suggestions={len(action.suggestions)}",
                    f"speedup={speedup:.2f}x",
                    f"correct={str(matches).lower()}",
                ]
                action_summary = ",".join(summary_parts)

                reward_history.append(step_reward)
                completed_steps = step_no
                observation = outcome.observation

                log_step(step=step_no, action=action_summary,
                         reward=step_reward, done=finished, error=model_error)

                if finished:
                    break

            best_score = max(reward_history, default=0.0)
            succeeded = best_score >= 0.5

        finally:
            # Always close the task's log block, even if a step raised.
            log_end(success=succeeded, steps=completed_steps,
                    score=best_score, rewards=reward_history)

        summary_by_task[task_id] = {
            "task_name":   observation.task_name,
            "final_score": round(best_score, 4),
            "steps_taken": completed_steps,
        }

    return summary_by_task


# Script entry point: run all tasks when this file is executed directly.
# The results dict returned by main() is not used here.
if __name__ == "__main__":
    main()