File size: 6,687 Bytes
569613d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf45353
569613d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import argparse
import json
import os
import sys
import textwrap
from typing import List, Optional

from dotenv import load_dotenv
load_dotenv()

from openai import OpenAI

from client import SqlSandboxEnv
from models import SqlSandboxAction

# ---------------------------------------------------------------------------
# Runtime configuration, read once from the environment at import time.
# API_BASE_URL and MODEL_NAME fall back to hard-coded defaults when the
# variable is unset or empty. API_KEY has NO default here: it is None (or "")
# when none of HF_TOKEN / API_KEY / OPENAI_API_KEY is set, and main()
# substitutes a placeholder key before building the OpenAI client.
# ---------------------------------------------------------------------------
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
# First non-empty value wins, checked in this order.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
# Label printed as `env=` in the [START] log line.
BENCHMARK = "sql_sandbox"

# System message sent as the first chat message of every task. It defines the
# JSON action format the agent loop parses ({"tool": ..., "command": ...}) and
# the `SELECT 'DONE'` action the model is told to send when it is finished.
# The text is flush-left, so dedent() leaves it unchanged; strip() removes the
# leading and trailing newline of the triple-quoted literal.
SYSTEM_PROMPT = textwrap.dedent("""
You are a data engineering assistant working inside a SQLite sandbox.

You can execute two types of actions:
1. {"tool": "sql",    "command": "<SQL query>"}
2. {"tool": "python", "command": "<Python code>"}

Rules:
1 Respond with EXACTLY ONE JSON object per turn  no markdown, no explanation.
2 In Python code, the variables `conn` (sqlite3.Connection) and `cursor`
  (sqlite3.Cursor) are already available. Do NOT call sqlite3.connect().
3 SQLite STRFTIME months are zero-padded: use '01' not '1', or use LIKE '2024-01-%'.
4 When you believe the task is fully complete, send:
  {"tool": "sql", "command": "SELECT 'DONE'"}
""").strip()


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` marker line for one task run.

    Args:
        task: Identifier of the task being attempted.
        env: Name of the benchmark/environment.
        model: Name of the model driving the agent.
    """
    header = "[START] task={} env={} model={}".format(task, env, model)
    print(header, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line describing an executed action.

    Args:
        step: 1-based index of the executed step.
        action: Already-serialised action text to echo.
        reward: Reward returned for this step; printed with two decimals.
        done: Whether the episode ended on this step; printed lowercased.
        error: Error text for the step. Newlines are flattened to spaces so
            the record stays on one line; a missing or empty error prints
            as ``null``.
    """
    if error:
        error_text = " ".join(error.split("\n"))
    else:
        error_text = "null"

    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error_text}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary line for one task run.

    Args:
        success: Whether the run counts as successful; printed lowercased.
        steps: Number of steps that were executed.
        score: Final score; printed with two decimals.
        rewards: Per-step rewards; printed comma-separated with two decimals
            each (empty when no step was executed).
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    summary = "[END] success={} steps={} score={:.2f} rewards={}".format(
        str(success).lower(),
        steps,
        score,
        ",".join(formatted_rewards),
    )
    print(summary, flush=True)


def _parse_action(assistant_msg: str) -> Optional[tuple]:
    """Extract ``(tool, command)`` from a raw model reply.

    Accepts either a bare JSON object or one wrapped in a Markdown code fence
    (optionally tagged ``json``).

    Args:
        assistant_msg: The model's reply text.

    Returns:
        A ``(tool, command)`` tuple, or ``None`` when the reply is not valid
        JSON, is valid JSON but not an object (string, list, number, null),
        or lacks the ``"tool"`` / ``"command"`` keys.
    """
    raw = assistant_msg
    if raw.startswith("```"):
        # Keep only the text between the first pair of fences and drop an
        # optional "json" language tag.
        raw = raw.split("```")[1]
        if raw.startswith("json"):
            raw = raw[4:]

    try:
        action_data = json.loads(raw)
    except json.JSONDecodeError:
        return None

    # json.loads happily returns str/list/int/None for valid non-object JSON;
    # subscripting those with a string key would raise TypeError.
    if not isinstance(action_data, dict):
        return None

    try:
        return action_data["tool"], action_data["command"]
    except KeyError:
        return None


def _run_task_agent(client_llm: OpenAI, base_url: str, task_id: str, max_turns: int = 15) -> float:
    """Run the LLM agent loop for a single task and return its score.

    Each turn asks the model for one JSON action, executes it in the sandbox
    environment, logs a ``[STEP]`` line and feeds the result back to the
    model. The loop ends when the environment reports ``done``, when
    ``env.step()`` raises, or after ``max_turns`` turns. Replies that cannot
    be parsed into an action consume a turn but are not counted as steps.

    Args:
        client_llm: OpenAI-compatible client used for chat completions.
        base_url: Base URL of the running environment server.
        task_id: Identifier passed to ``env.reset()``.
        max_turns: Maximum number of model turns for this task.

    Returns:
        The sum of per-step rewards clamped to ``[0.01, 0.99]``, or ``0.0``
        if ``env.reset()`` fails (in which case no [START]/[END] lines are
        printed).
    """
    rewards: List[float] = []
    step_count = 0

    # Action sent when the model request itself fails (network/API error, or
    # a response whose content cannot be read).
    fallback_action = '{"tool": "sql", "command": "SELECT \'DONE\'"}'

    with SqlSandboxEnv(base_url=base_url).sync() as env:
        try:
            reset_resp = env.reset(task_id=task_id)
            task_desc = reset_resp.observation.task_description
        except Exception as e:
            print(f"[DEBUG] env.reset() error for task {task_id}: {e}", flush=True)
            return 0.0

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user",   "content": f"Task: {task_desc}\n\nBegin."},
        ]

        log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

        for _ in range(max_turns):
            # 1. Ask the LLM; any failure degrades to the fallback action
            #    instead of aborting the run.
            try:
                response = client_llm.chat.completions.create(
                    model=MODEL_NAME,
                    messages=messages,
                    temperature=0.0,
                    max_tokens=512,
                )
                assistant_msg = response.choices[0].message.content.strip()
            except Exception as exc:
                print(f"[DEBUG] Model request failed: {exc}", flush=True)
                assistant_msg = fallback_action

            # 2. Parse action JSON
            parsed = _parse_action(assistant_msg)
            if parsed is None:
                # Feed parse error back to LLM, do NOT count as a step
                messages.append({"role": "assistant", "content": assistant_msg})
                messages.append({
                    "role": "user",
                    "content": (
                        'Invalid JSON. Reply with exactly one JSON object:\n'
                        '{"tool": "sql" | "python", "command": "..."}'
                    ),
                })
                continue
            tool, command = parsed

            # 3. Execute the action
            try:
                step_resp = env.step(SqlSandboxAction(tool=tool, command=command))
            except Exception as exc:
                print(f"[DEBUG] env.step() error: {exc}", flush=True)
                break

            reward = step_resp.reward or 0.0
            done   = step_resp.done
            output = step_resp.observation.output or ""
            error  = step_resp.observation.error  or ""

            rewards.append(reward)
            step_count += 1

            action_str = json.dumps({"tool": tool, "command": command})
            log_step(step=step_count, action=action_str, reward=reward, done=done, error=error)

            if done:
                break

            # 4. Feed result back to LLM for the next turn (output and error
            #    are truncated to keep the conversation small).
            messages.append({"role": "assistant", "content": assistant_msg})
            feedback = f"Output:\n{output[:1500]}"
            if error:
                feedback += f"\nError:\n{error[:500]}"
            feedback += f"\nReward so far: {reward:.4f}"
            messages.append({"role": "user", "content": feedback})

        # Clamp the summed rewards into [0.01, 0.99]; success therefore
        # requires the raw sum to reach at least 0.99.
        raw_score = sum(rewards)
        final_score = max(0.01, min(0.99, float(raw_score)))
        success = final_score >= 0.99

        log_end(success=success, steps=step_count, score=final_score, rewards=rewards)
        return final_score


def main():
    """Parse CLI options and run the agent on task1 through task6 in order.

    Options:
        --url: Base URL of the environment server.
        --max-turns: Upper bound on model turns per task.

    A missing API key is reported but does not stop the run; a placeholder
    key is passed to the client instead.
    """
    cli = argparse.ArgumentParser(
        description="OpenAI baseline inference for the SQL/Data Cleaning Sandbox"
    )
    cli.add_argument(
        "--url",
        default="http://localhost:7860",
        help="Base URL of the running environment server",
    )
    cli.add_argument(
        "--max-turns",
        type=int,
        default=15,
        help="Maximum agent turns per task (default: 15)",
    )
    options = cli.parse_args()

    if not API_KEY:
        print("ERROR: HF_TOKEN (or OPENAI_API_KEY) environment variable is not set.", flush=True)

    # The client is always constructed, even without a real key.
    llm = OpenAI(api_key=API_KEY or "dummy_key", base_url=API_BASE_URL)

    for task_number in range(1, 7):
        _run_task_agent(llm, options.url, f"task{task_number}", options.max_turns)


# Script entry point: run the task loop only when executed directly, not
# when this module is imported.
if __name__ == "__main__":
    main()