Meta-Pytorch-Openenv / inference.py
shreyas231219's picture
Upload folder using huggingface_hub
6ea2f5b verified
"""
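Baseline Inference Script for SQL/Data Cleaning Sandbox — OpenAI Edition.
Uses an OpenAI-compatible chat model (gpt-4o by default) to solve all three
tasks and prints reproducible scores via the OpenEnv WebSocket client.

Configuration (environment variables; a local .env file is also loaded):
    HF_TOKEN or OPENAI_API_KEY  API key for the model endpoint (required)
    API_BASE_URL                optional base URL of the model endpoint
    MODEL_NAME                  optional model name (default: gpt-4o)

Usage:
    set HF_TOKEN=sk-...                       # Windows
    export HF_TOKEN=sk-...                    # Linux/macOS
    python inference.py                       # local server
    python inference.py --url https://...     # remote server
    python inference.py --max-turns 20        # custom per-task turn limit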
"""
import argparse
import json
import os
import sys
from dotenv import load_dotenv
load_dotenv()
from openai import OpenAI
from client import SqlSandboxEnv
from models import SqlSandboxAction
# ---------------------------------------------------------------------------
# System prompt shared across all tasks
# ---------------------------------------------------------------------------
# System prompt sent as the first message of every task session (see
# _run_task_agent). It defines the reply protocol: exactly one JSON object
# per turn carrying a "tool" ("sql" or "python") and a "command" key.
# NOTE(review): the "SELECT 'DONE'" completion signal is only a convention
# for the model; whether the server treats it specially is not visible here.
# Keep this text exact: any edit changes what the model is told.
SYSTEM_PROMPT = """\
You are a data engineering assistant working inside a SQLite sandbox.
You can execute two types of actions:
1. {"tool": "sql", "command": "<SQL query>"}
2. {"tool": "python", "command": "<Python code>"}
Rules:
- Respond with EXACTLY ONE JSON object per turn no markdown, no explanation.
- In Python code, the variables `conn` (sqlite3.Connection) and `cursor`
(sqlite3.Cursor) are already available. Do NOT call sqlite3.connect().
- SQLite STRFTIME months are zero-padded: use '01' not '1', or use LIKE '2024-01-%'.
- When you believe the task is fully complete, send:
{"tool": "sql", "command": "SELECT 'DONE'"}
"""
# ---------------------------------------------------------------------------
# Core agent loop — one task, one WebSocket session
# ---------------------------------------------------------------------------
def _parse_action(text: str) -> "tuple[str, str] | None":
    """Extract ``(tool, command)`` from an LLM reply, or return ``None``.

    Accepts a bare JSON object or one wrapped in a markdown code fence
    (optionally tagged ``json``/``JSON``). Returns ``None`` when the payload
    is not valid JSON, is not a JSON object, or lacks string ``tool`` and
    ``command`` fields, so the caller can ask the model to retry instead of
    crashing on a ``TypeError``/``KeyError``.
    """
    raw = text.strip()
    if raw.startswith("```"):
        # Keep only the fenced body, then drop an optional language tag.
        raw = raw.split("```")[1]
        if raw[:4].lower() == "json":
            raw = raw[4:]
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    # Valid JSON that is not an object (e.g. a number or a list) is not an action.
    if not isinstance(data, dict):
        return None
    tool = data.get("tool")
    command = data.get("command")
    if not isinstance(tool, str) or not isinstance(command, str):
        return None
    return tool, command


def _run_task_agent(base_url: str, task_id: str, max_turns: int = 15) -> float:
    """Run one LLM agent episode against a fresh environment session.

    Opens a new WebSocket session, resets the environment to ``task_id``,
    then alternates "ask the LLM" -> ``env.step()`` until the environment
    reports ``done`` or ``max_turns`` loop iterations have been used.

    Args:
        base_url: Base URL of the running environment server.
        task_id: Task to reset the environment to (this script uses
            ``"easy"``, ``"medium"`` and ``"hard"``).
        max_turns: Upper bound on LLM calls for this task. A reply that
            cannot be parsed never reaches ``env.step()`` but still uses up
            one turn.

    Returns:
        The reward reported by the last executed step, or 0.0 if no step
        was executed.
    """
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    api_base_url = os.environ.get("API_BASE_URL")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o")
    client_llm = OpenAI(
        api_key=api_key,
        base_url=api_base_url,
    )

    final_reward = 0.0
    # Each task gets its own WebSocket session to avoid state leakage.
    with SqlSandboxEnv(base_url=base_url).sync() as env:
        # reset() with task_id seeds the correct DB table for this task.
        reset_resp = env.reset(task_id=task_id)
        task_desc = reset_resp.observation.task_description
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Task: {task_desc}\n\nBegin."},
        ]
        print(f"\n --- Session: {task_id} ---")

        for turn in range(max_turns):
            # 1. Ask the LLM. The SDK types `content` as optional, so
            #    normalise None to "" and let the parse step request a retry.
            response = client_llm.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
            assistant_msg = (response.choices[0].message.content or "").strip()
            messages.append({"role": "assistant", "content": assistant_msg})

            # 2. Parse the action JSON (markdown fences tolerated).
            action = _parse_action(assistant_msg)
            if action is None:
                # Feed the parse error back without calling env.step(); this
                # still consumes one of the max_turns iterations.
                messages.append({
                    "role": "user",
                    "content": (
                        'Invalid JSON. Reply with exactly one JSON object:\n'
                        '{"tool": "sql" | "python", "command": "..."}'
                    ),
                })
                continue
            tool, command = action

            # 3. Execute the action via OpenEnv step().
            step_resp = env.step(SqlSandboxAction(tool=tool, command=command))
            reward = step_resp.reward or 0.0
            done = step_resp.done
            output = step_resp.observation.output or ""
            error = step_resp.observation.error or ""
            final_reward = reward
            print(f" [Turn {turn+1:02d}] tool={tool:<6} | reward={reward:.4f} | done={done}")
            if done:
                break

            # 4. Feed the (truncated) result back to the LLM for the next turn.
            feedback = f"Output:\n{output[:1500]}"
            if error:
                feedback += f"\nError:\n{error[:500]}"
            feedback += f"\nReward so far: {reward:.4f}"
            messages.append({"role": "user", "content": feedback})
    return final_reward
# ---------------------------------------------------------------------------
# Per-difficulty entry points (called by main, importable for custom use)
# ---------------------------------------------------------------------------
def easy_run(base_url: str, max_turns: int = 15) -> float:
    """Run the "easy" task via _run_task_agent, print its final score, return it."""
    rule = "=" * 50
    print(f"\n{rule}\nRunning task: easy\n{rule}")
    final = _run_task_agent(base_url, "easy", max_turns)
    print(f" Final score: {final:.4f}")
    return final
def med_run(base_url: str, max_turns: int = 15) -> float:
    """Run the "medium" task via _run_task_agent, print its final score, return it."""
    task = "medium"
    divider = "=" * 50
    print("\n" + divider + "\nRunning task: " + task + "\n" + divider)
    outcome = _run_task_agent(base_url, task, max_turns)
    print(f" Final score: {outcome:.4f}")
    return outcome
def hard_run(base_url: str, max_turns: int = 15) -> float:
    """Run the "hard" task via _run_task_agent, print its final score, return it."""
    header = "\n".join(("", "=" * 50, "Running task: hard", "=" * 50))
    print(header)
    hard_score = _run_task_agent(base_url, "hard", max_turns)
    print(" Final score: " + format(hard_score, ".4f"))
    return hard_score
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: run the easy, medium and hard tasks and print a summary.

    Reads ``--url`` and ``--max-turns`` from the command line. Exits with an
    argparse usage error when ``--max-turns`` is below 1, and with status 1
    (message on stderr) when neither ``HF_TOKEN`` nor ``OPENAI_API_KEY`` is
    set.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI baseline inference for the SQL/Data Cleaning Sandbox"
    )
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="Base URL of the running environment server (default: http://localhost:7860)",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=15,
        help="Maximum agent turns per task (default: 15)",
    )
    args = parser.parse_args()
    if args.max_turns < 1:
        # With no turns the agent loop never calls the LLM, so every task
        # would report 0.0 and the summary would look like a real run.
        parser.error("--max-turns must be at least 1")
    if not os.environ.get("HF_TOKEN") and not os.environ.get("OPENAI_API_KEY"):
        # Errors go to stderr so stdout carries only the run output.
        print(
            "ERROR: HF_TOKEN (or OPENAI_API_KEY) environment variable is not set per checklist.",
            file=sys.stderr,
        )
        sys.exit(1)

    # Fixed order keeps the per-task log and the summary reproducible.
    runners = (
        ("easy", easy_run),
        ("medium", med_run),
        ("hard", hard_run),
    )
    results: dict[str, float] = {
        task_id: run(args.url, args.max_turns) for task_id, run in runners
    }
    avg = sum(results.values()) / len(results)

    print(f"\n{'='*50}")
    print("RESULTS SUMMARY")
    print(f"{'='*50}")
    for task_id, score in results.items():
        print(f" {task_id:<10}: {score:.4f}")
    print(f" {'average':<10}: {avg:.4f}")
if __name__ == "__main__":
    # Script entry only; importing this module does not start a run.
    # main() returns None, so the process still exits with status 0.
    sys.exit(main())