sql-query-debugger / inference.py
gourav03003's picture
feat: achieve 1.0 baseline score across all tasks
2950d2e
import asyncio
import os
import textwrap
from typing import List, Optional
from openai import OpenAI
from models import SqlQueryDebuggerAction, SqlQueryDebuggerObservation
from server.sql_query_debugger_environment import SqlQueryDebuggerEnvironment, SCENARIOS
# ── env vars ────────────────────────────────────────────────────────
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
BENCHMARK = os.getenv("BENCHMARK", "sql_query_debugger")
MAX_STEPS = 5
SUCCESS_THRESHOLD = 0.5
# ── logging helpers ─────────────────────────────────────────────────
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float,
done: bool, error: Optional[str]) -> None:
err = error if error else "null"
done_val = str(done).lower()
action_clean = action.replace("\n", " ").strip()[:120]
print(
f"[STEP] step={step} action={action_clean} "
f"reward={reward:.2f} done={done_val} error={err}",
flush=True,
)
def log_end(success: bool, steps: int,
score: float, rewards: List[float]) -> None:
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} "
f"score={score:.3f} rewards={rewards_str}",
flush=True,
)
# ── agent prompt ─────────────────────────────────────────────────────
SYSTEM_PROMPT = textwrap.dedent("""
You are an expert SQL debugger.
You will be given a broken SQL query, the database schema, an error message,
sample data rows, and a hint about what the correct output should be.
Your job is to return ONLY the fixed SQL query β€” nothing else.
No explanation, no markdown, no code blocks. Just the raw SQL query ending with a semicolon.
""").strip()
def build_user_prompt(obs: SqlQueryDebuggerObservation) -> str:
return textwrap.dedent(f"""
Task difficulty: {obs.task_id}
Database schema:
{obs.schema}
Sample rows:
{obs.sample_rows}
Broken query:
{obs.broken_query}
Error message:
{obs.error_message if obs.error_message else "No error β€” but output is wrong"}
Expected output hint:
{obs.expected_output_hint}
Last attempt result:
{obs.last_result if obs.last_result else "No attempt yet"}
Attempts remaining: {obs.attempts_remaining}
Return ONLY the fixed SQL query:
""").strip()
def get_fixed_query(client: OpenAI, obs: SqlQueryDebuggerObservation) -> str:
try:
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": build_user_prompt(obs)},
],
temperature=0.2,
max_tokens=256,
stream=False,
)
query = (response.choices[0].message.content or "").strip()
# strip markdown if model wraps in code block
if query.startswith("```"):
lines = query.split("\n")
query = "\n".join(
l for l in lines
if not l.startswith("```")
).strip()
return query if query else "SELECT 1;"
except Exception as exc:
print(f"[DEBUG] Model request failed: {exc}", flush=True)
return "SELECT 1;"
# ── one episode ──────────────────────────────────────────────────────
async def run_episode(task_id: str) -> float:
env = SqlQueryDebuggerEnvironment()
rewards: List[float] = []
steps_taken = 0
score = 0.0
success = False
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
try:
obs = env.reset(task_id=task_id)
for step in range(1, MAX_STEPS + 1):
if obs.done:
break
fixed_query = get_fixed_query(client, obs)
action = SqlQueryDebuggerAction(fixed_query=fixed_query)
obs = env.step(action)
current_reward = obs.reward or 0.0
rewards.append(current_reward)
steps_taken = step
log_step(
step = step,
action = fixed_query,
reward = current_reward,
done = obs.done,
error = obs.error_message if obs.error_message else None,
)
if obs.done:
break
# Calculate final metrics based on the episode results
if rewards:
# Score is the maximum reward reached (captures early solve bonuses)
score = max(rewards)
# success is true if any step reached the solution threshold
success = any(r >= 0.99 for r in rewards)
except Exception as e:
print(f"[DEBUG] Episode failed with error: {e}", flush=True)
finally:
# Mandatory: Always emit [END] log with correct formatting
final_score_clamped = min(max(score, 0.0), 1.0)
log_end(
success = success,
steps = steps_taken,
score = final_score_clamped,
rewards = rewards,
)
return score
# ── main: run all 3 tasks ────────────────────────────────────────────
async def main() -> None:
task_ids = ["syntax_fix", "logic_bug", "multi_table"]
all_scores = []
for task_id in task_ids:
score = await run_episode(task_id)
all_scores.append(score)
print(f"[DEBUG] {task_id} score: {score:.3f}", flush=True)
avg = sum(all_scores) / len(all_scores)
print(f"[DEBUG] Average score across all tasks: {avg:.3f}", flush=True)
if __name__ == "__main__":
asyncio.run(main())