"""
Baseline Inference Script for SQL/Data Cleaning Sandbox OpenAI Edition.
Uses OpenAI (gpt-4o) to solve all three tasks and prints reproducible
scores via the OpenEnv WebSocket client.
Usage:
set HF_TOKEN=sk-... # Windows
export HF_TOKEN=sk-... # Linux/macOS
python inference.py # local server
python inference.py --url https://... # remote server
"""
import argparse
import json
import os
import sys
from dotenv import load_dotenv
load_dotenv()
from openai import OpenAI
from client import SqlSandboxEnv
from models import SqlSandboxAction
# ---------------------------------------------------------------------------
# System prompt shared across all tasks
# ---------------------------------------------------------------------------
# Instructions sent as the "system" message for every task. Defines the two
# available tool actions (sql / python), the strict one-JSON-object-per-turn
# reply format, the pre-bound `conn`/`cursor` variables, and the completion
# sentinel (SELECT 'DONE'). Runtime text — do not edit casually: the agent
# loop in _run_task_agent parses replies shaped by these rules.
SYSTEM_PROMPT = """\
You are a data engineering assistant working inside a SQLite sandbox.
You can execute two types of actions:
1. {"tool": "sql", "command": "<SQL query>"}
2. {"tool": "python", "command": "<Python code>"}
Rules:
- Respond with EXACTLY ONE JSON object per turn no markdown, no explanation.
- In Python code, the variables `conn` (sqlite3.Connection) and `cursor`
(sqlite3.Cursor) are already available. Do NOT call sqlite3.connect().
- SQLite STRFTIME months are zero-padded: use '01' not '1', or use LIKE '2024-01-%'.
- When you believe the task is fully complete, send:
{"tool": "sql", "command": "SELECT 'DONE'"}
"""
# ---------------------------------------------------------------------------
# Core agent loop one task, one WebSocket session
# ---------------------------------------------------------------------------
def _run_task_agent(base_url: str, task_id: str, max_turns: int = 15) -> float:
    """
    Open a fresh WebSocket session, reset the environment to the given task,
    then run an LLM agent loop until done or max_turns is reached.

    Args:
        base_url: Base URL of the running environment server.
        task_id: Task identifier to seed the environment with
            (e.g. "easy", "medium", "hard").
        max_turns: Maximum number of environment steps before giving up.

    Returns:
        The final reward reported by the environment (0.0-1.0).
    """
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    api_base_url = os.environ.get("API_BASE_URL")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o")
    client_llm = OpenAI(
        api_key=api_key,
        base_url=api_base_url,
    )
    final_reward = 0.0
    # Each task gets its own WebSocket session to avoid state leakage
    with SqlSandboxEnv(base_url=base_url).sync() as env:
        # reset() with task_id seeds the correct DB table for this task
        reset_resp = env.reset(task_id=task_id)
        task_desc = reset_resp.observation.task_description
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Task: {task_desc}\n\nBegin."},
        ]
        print(f"\n --- Session: {task_id} ---")
        for turn in range(max_turns):
            # 1. Ask the LLM
            response = client_llm.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
            # message.content may be None (e.g. refusal or tool-call reply);
            # coalesce to "" so .strip() cannot raise AttributeError.
            assistant_msg = (response.choices[0].message.content or "").strip()
            # 2. Parse action JSON (handle optional markdown fences)
            try:
                raw = assistant_msg
                if raw.startswith("```"):
                    # Take the fenced body and drop an optional language tag.
                    raw = raw.split("```")[1]
                if raw.startswith("json"):
                    raw = raw.removeprefix("json").strip()
                action_data = json.loads(raw)
                tool = action_data["tool"]
                command = action_data["command"]
            except (json.JSONDecodeError, KeyError, TypeError):
                # TypeError covers a valid-JSON-but-not-an-object reply
                # (e.g. a list), which must also be re-prompted, not crash.
                # Feed parse error back to LLM, do NOT count as an env step
                # (env.step() is never called on this path).
                messages.append({"role": "assistant", "content": assistant_msg})
                messages.append({
                    "role": "user",
                    "content": (
                        'Invalid JSON. Reply with exactly one JSON object:\n'
                        '{"tool": "sql" | "python", "command": "..."}'
                    ),
                })
                continue
            # 3. Execute the action via OpenEnv step()
            step_resp = env.step(SqlSandboxAction(tool=tool, command=command))
            reward = step_resp.reward or 0.0
            done = step_resp.done
            output = step_resp.observation.output or ""
            error = step_resp.observation.error or ""
            final_reward = reward
            print(f" [Turn {turn+1:02d}] tool={tool:<6} | reward={reward:.4f} | done={done}")
            if done:
                break
            # 4. Feed result back to LLM for the next turn (truncated so the
            # context window is not flooded by large query outputs).
            messages.append({"role": "assistant", "content": assistant_msg})
            feedback = f"Output:\n{output[:1500]}"
            if error:
                feedback += f"\nError:\n{error[:500]}"
            feedback += f"\nReward so far: {reward:.4f}"
            messages.append({"role": "user", "content": feedback})
    return final_reward
# ---------------------------------------------------------------------------
# Per-difficulty entry points (called by main, importable for custom use)
# ---------------------------------------------------------------------------
def _announced_run(base_url: str, task_id: str, max_turns: int) -> float:
    """Print a banner, run one task via _run_task_agent, print its score."""
    print(f"\n{'='*50}\nRunning task: {task_id}\n{'='*50}")
    score = _run_task_agent(base_url, task_id, max_turns)
    print(f" Final score: {score:.4f}")
    return score


def easy_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'easy' task and return its final reward (0.0-1.0)."""
    return _announced_run(base_url, "easy", max_turns)


def med_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'medium' task and return its final reward (0.0-1.0)."""
    return _announced_run(base_url, "medium", max_turns)


def hard_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'hard' task and return its final reward (0.0-1.0)."""
    return _announced_run(base_url, "hard", max_turns)
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """CLI entry point: parse arguments, run all three tasks, print a summary.

    Exits with status 1 when no API credential (HF_TOKEN or OPENAI_API_KEY)
    is configured, before opening any environment session.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI baseline inference for the SQL/Data Cleaning Sandbox"
    )
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="Base URL of the running environment server (default: http://localhost:7860)",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=15,
        help="Maximum agent turns per task (default: 15)",
    )
    args = parser.parse_args()
    # Fail fast with a clear, user-facing message when no credential is set
    # (fixed: dropped stray internal note "per checklist" from the message).
    if not os.environ.get("HF_TOKEN") and not os.environ.get("OPENAI_API_KEY"):
        print("ERROR: HF_TOKEN (or OPENAI_API_KEY) environment variable is not set.")
        sys.exit(1)
    results: dict[str, float] = {}
    results["easy"] = easy_run(args.url, args.max_turns)
    results["medium"] = med_run(args.url, args.max_turns)
    results["hard"] = hard_run(args.url, args.max_turns)
    avg = sum(results.values()) / len(results)
    print(f"\n{'='*50}")
    print("RESULTS SUMMARY")
    print(f"{'='*50}")
    for task_id, score in results.items():
        print(f" {task_id:<10}: {score:.4f}")
    print(f" {'average':<10}: {avg:.4f}")


if __name__ == "__main__":
    main()