""" Baseline inference script for the SQL Query Optimizer OpenEnv. Produces reproducible baseline scores on all 3 tasks using deterministic hardcoded optimal rewrites. Optionally uses the OpenAI API if OPENAI_API_KEY is set. Prints structured [START]/[STEP]/[END] output to stdout as required by the OpenEnv validation pipeline. Usage: python inference.py # or with LLM: OPENAI_API_KEY=sk-... python inference.py """ import os import sys from env.environment import SQLEnv from env.models import Action from env.tasks import TASKS # Deterministic baseline rewrites that score well on the graders BASELINE_REWRITES = { 1: "SELECT users.name, orders.amount FROM users JOIN orders ON users.id = orders.user_id;", 2: "SELECT e.name FROM employees e JOIN departments d ON e.dept_id = d.id WHERE d.name = 'Engineering';", 3: "SELECT s.id, s.product_id, s.sale_date, s.amount FROM sales s /* USE INDEX (idx_sales_date) */ WHERE s.sale_date = '2023-01-01';", } def get_rewrite_llm(obs, task_id: int) -> str: """Try to get a rewrite from the OpenAI API; fall back to baseline.""" api_key = os.environ.get("OPENAI_API_KEY") if not api_key: return BASELINE_REWRITES[task_id] try: from openai import OpenAI client = OpenAI(api_key=api_key) messages = [ { "role": "system", "content": ( "You are an expert SQL DBA. You rewrite SQL queries " "to be correct, optimized, and performant." ), }, { "role": "user", "content": ( f"Task #{obs.task_id}\n" f"Original Query: {obs.query}\n" f"Database Schema Context: {obs.schema_context}\n" f"Hint: {obs.hint}\n\n" "Please provide the optimized query. " "Output ONLY the raw SQL query, no markdown formatting, no explanation." ), }, ] response = client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, temperature=0.0, ) rewritten = response.choices[0].message.content.strip() # Strip markdown fences if present if rewritten.startswith("```sql"): rewritten = rewritten[6:] if rewritten.startswith("```"): rewritten = rewritten[3:] if rewritten.endswith("```"): rewritten = rewritten[:-3] return rewritten.strip() except Exception as e: print(f"LLM call failed ({e}), using deterministic baseline", flush=True) return BASELINE_REWRITES[task_id] def run_task(env: SQLEnv, task_id: int, task_name: str) -> float: """Run a single task and print structured output.""" print(f"[START] task={task_name}", flush=True) obs = env.reset(task_id=task_id) rewritten_query = get_rewrite_llm(obs, task_id) action = Action( rewritten_query=rewritten_query, explanation="Baseline inference rewrite", is_done=True, ) result_obs = env.step(action) reward = result_obs.reward grader_score = env.final_grader_score step_count = env.step_number - 1 # step_number was incremented after step() print(f"[STEP] step=1 reward={reward}", flush=True) print(f"[END] task={task_name} score={grader_score} steps={step_count}", flush=True) return grader_score def main(): env = SQLEnv() scores = {} for task_id, task_info in TASKS.items(): task_name = task_info["name"] score = run_task(env, task_id, task_name) scores[task_id] = score # Summary print("\n=== Baseline Evaluation Results ===", flush=True) for tid, score in scores.items(): print(f" Task {tid} ({TASKS[tid]['name']}): {score}/1.0", flush=True) if __name__ == "__main__": main()