File size: 7,694 Bytes
1f008d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ea2f5b
 
1f008d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""


Baseline Inference Script for SQL/Data Cleaning Sandbox — OpenAI Edition.





Uses OpenAI (gpt-4o) to solve all three tasks and prints reproducible


scores via the OpenEnv WebSocket client.





Usage:


    set HF_TOKEN=sk-...          # Windows


    export HF_TOKEN=sk-...       # Linux/macOS


    python inference.py                    # local server


    python inference.py --url https://...  # remote server


"""



import argparse

import json

import os

import sys



from dotenv import load_dotenv

load_dotenv()



from openai import OpenAI



from client import SqlSandboxEnv

from models import SqlSandboxAction





# ---------------------------------------------------------------------------

# System prompt shared across all tasks

# ---------------------------------------------------------------------------

# The "SELECT 'DONE'" sentinel below is how the agent signals completion.
SYSTEM_PROMPT = """\


You are a data engineering assistant working inside a SQLite sandbox.




You can execute two types of actions:


1. {"tool": "sql",    "command": "<SQL query>"}


2. {"tool": "python", "command": "<Python code>"}




Rules:


- Respond with EXACTLY ONE JSON object per turn — no markdown, no explanation.


- In Python code, the variables `conn` (sqlite3.Connection) and `cursor`


  (sqlite3.Cursor) are already available. Do NOT call sqlite3.connect().


- SQLite STRFTIME months are zero-padded: use '01' not '1', or use LIKE '2024-01-%'.


- When you believe the task is fully complete, send:


  {"tool": "sql", "command": "SELECT 'DONE'"}


"""





# ---------------------------------------------------------------------------

# Core agent loop  one task, one WebSocket session

# ---------------------------------------------------------------------------

def _run_task_agent(base_url: str, task_id: str, max_turns: int = 15) -> float:
    """
    Open a fresh WebSocket session, reset the environment to the given task,
    then run an LLM agent loop until done or ``max_turns`` actions executed.

    Parameters
    ----------
    base_url : str
        Base URL of the running environment server.
    task_id : str
        Task to seed the environment with (e.g. "easy", "medium", "hard").
    max_turns : int
        Maximum number of *executed* environment steps. Parse failures are
        retried without consuming a turn, bounded by a hard cap on LLM calls.

    Returns
    -------
    float
        The final reward (0.0 - 1.0).
    """
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    api_base_url = os.environ.get("API_BASE_URL")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o")

    client_llm = OpenAI(
        api_key=api_key,
        base_url=api_base_url,
    )
    final_reward = 0.0

    # Each task gets its own WebSocket session to avoid state leakage
    with SqlSandboxEnv(base_url=base_url).sync() as env:
        # reset() with task_id seeds the correct DB table for this task
        reset_resp = env.reset(task_id=task_id)
        task_desc = reset_resp.observation.task_description

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user",   "content": f"Task: {task_desc}\n\nBegin."},
        ]

        print(f"\n  --- Session: {task_id} ---")

        turn = 0
        llm_calls = 0
        # BUG FIX: the old loop was `for turn in range(max_turns)` with
        # `continue` on parse errors, so a malformed reply DID consume a
        # turn despite the comment claiming otherwise. We now count only
        # executed actions as turns, and cap total LLM calls so repeated
        # parse failures cannot spin forever.
        max_llm_calls = max_turns * 2
        while turn < max_turns and llm_calls < max_llm_calls:
            llm_calls += 1

            # 1. Ask the LLM
            response = client_llm.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
            # content may be None (e.g. empty/refused completion) -- normalize
            # to "" so .strip() cannot raise AttributeError.
            assistant_msg = (response.choices[0].message.content or "").strip()

            # 2. Parse action JSON (handle optional markdown fences)
            try:
                raw = assistant_msg
                if raw.startswith("```"):
                    raw = raw.split("```")[1]
                    if raw.startswith("json"):
                        raw = raw[4:]
                action_data = json.loads(raw)
                tool    = action_data["tool"]
                command = action_data["command"]
            except (json.JSONDecodeError, KeyError, TypeError):
                # TypeError covers valid JSON that is not an object (a bare
                # list/string/number would make the subscripts above raise it).
                # Feed the parse error back to the LLM; this does NOT count
                # as a turn.
                messages.append({"role": "assistant", "content": assistant_msg})
                messages.append({
                    "role": "user",
                    "content": (
                        'Invalid JSON. Reply with exactly one JSON object:\n'
                        '{"tool": "sql" | "python", "command": "..."}'
                    ),
                })
                continue

            turn += 1

            # 3. Execute the action via OpenEnv step()
            step_resp = env.step(SqlSandboxAction(tool=tool, command=command))

            reward = step_resp.reward or 0.0
            done   = step_resp.done
            output = step_resp.observation.output or ""
            error  = step_resp.observation.error  or ""

            final_reward = reward
            print(f"  [Turn {turn:02d}] tool={tool:<6} | reward={reward:.4f} | done={done}")

            if done:
                break

            # 4. Feed result back to LLM for the next turn
            messages.append({"role": "assistant", "content": assistant_msg})
            feedback = f"Output:\n{output[:1500]}"
            if error:
                feedback += f"\nError:\n{error[:500]}"
            feedback += f"\nReward so far: {reward:.4f}"
            messages.append({"role": "user", "content": feedback})

    return final_reward





# ---------------------------------------------------------------------------

# Per-difficulty entry points (called by main, importable for custom use)

# ---------------------------------------------------------------------------

def easy_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'easy' task in its own session and return the final reward."""
    banner = "=" * 50
    print(f"\n{banner}\nRunning task: easy\n{banner}")
    final = _run_task_agent(base_url, "easy", max_turns)
    print(f"  Final score: {final:.4f}")
    return final





def med_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'medium' task in its own session and return the final reward."""
    banner = "=" * 50
    print(f"\n{banner}\nRunning task: medium\n{banner}")
    final = _run_task_agent(base_url, "medium", max_turns)
    print(f"  Final score: {final:.4f}")
    return final





def hard_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'hard' task in its own session and return the final reward."""
    banner = "=" * 50
    print(f"\n{banner}\nRunning task: hard\n{banner}")
    final = _run_task_agent(base_url, "hard", max_turns)
    print(f"  Final score: {final:.4f}")
    return final





# ---------------------------------------------------------------------------

# CLI entry point

# ---------------------------------------------------------------------------

def main():
    """Parse CLI args, run all three tasks in order, and print a summary."""
    parser = argparse.ArgumentParser(
        description="OpenAI baseline inference for the SQL/Data Cleaning Sandbox"
    )
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="Base URL of the running environment server (default: http://localhost:7860)",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=15,
        help="Maximum agent turns per task (default: 15)",
    )
    args = parser.parse_args()

    # Credentials are required up front so we fail fast before any session.
    if not (os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")):
        print("ERROR: HF_TOKEN (or OPENAI_API_KEY) environment variable is not set per checklist.")
        sys.exit(1)

    # Dispatch table preserves the easy -> medium -> hard execution order.
    runners = {"easy": easy_run, "medium": med_run, "hard": hard_run}
    results: dict[str, float] = {
        name: runner(args.url, args.max_turns) for name, runner in runners.items()
    }

    separator = "=" * 50
    avg = sum(results.values()) / len(results)
    print(f"\n{separator}")
    print("RESULTS SUMMARY")
    print(f"{separator}")
    for name, score in results.items():
        print(f"  {name:<10}: {score:.4f}")
    print(f"  {'average':<10}: {avg:.4f}")


if __name__ == "__main__":
    main()