File size: 5,193 Bytes
de07414
 
 
 
 
 
 
8b003d5
 
 
 
de07414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b003d5
de07414
 
8b003d5
de07414
 
 
8b003d5
de07414
8b003d5
 
 
 
 
 
 
 
 
 
 
 
 
de07414
 
8b003d5
de07414
 
 
 
 
 
 
8b003d5
de07414
 
 
 
 
 
 
8b003d5
de07414
 
 
 
 
 
8b003d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de07414
8b003d5
de07414
 
 
 
 
1314b5a
 
8b003d5
 
de07414
8b003d5
de07414
8b003d5
de07414
8b003d5
de07414
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import json
import os
import sys

from openai import OpenAI

from server.models import Action, BrowserGymAction # using our local Action model alias
from server.app import env_instance as env

# Environment Configuration
# The four settings below are read from environment variables of the same
# name; the literal fallback is used only when the variable is unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
HF_TOKEN = os.environ.get("HF_TOKEN")  # passed to the OpenAI client as api_key; None if unset
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")  # not referenced elsewhere in this file

# Agent-loop limits and sampling parameters.
MAX_STEPS = 15      # maximum number of env.step calls per task
TEMPERATURE = 0.2   # sampling temperature for every completion request
MAX_TOKENS = 512    # max_tokens cap for every completion request

# System instructions sent as the "system" message of every chat request.
# They ask the model to reply with a JSON object whose "action_str" key holds
# exactly one SQL statement; replies are decoded by parse_model_action.
# NOTE: this text is sent verbatim, so any edit (including whitespace) changes
# what the model sees.
SYSTEM_PROMPT = """
You are an expert Data Engineer interacting with a simulated SQLite database.
You will be given a task goal, the current database schema, and the most recent step's SQL output or error.
Your goal is to complete the task by executing SQL commands. 

CRITICAL RULES:
1. You may only execute ONE SQL statement at a time. Do not chain statements with semicolons.
2. If you need to review data, use short SELECT queries.
3. If your previous action resulted in an SQL error, fix the error and try again.
4. If you need multiple steps to achieve the goal (e.g. create tables, then insert data), execute them one by one.
5. You MUST output ONLY a valid JSON object matching this schema:
{
  "action_str": "YOUR SQL QUERY HERE"
}
Do not wrap your response in markdown code blocks. Just valid JSON.
"""

def build_user_prompt(step: int, observation, history: list) -> str:
    """Render the per-step user message sent to the model.

    The message is made of, in order: a step header, the task goal, the
    schema dump (only when it is non-empty), the latest SQL result or error
    text, up to the three most recent history lines, and a closing request
    for the JSON reply.

    Args:
        step: 1-based step number shown in the header.
        observation: object exposing ``goal``, ``schema_dump`` and ``result``.
        history: previously logged step lines (strings); only the last three
            are included.
    """
    sections = [
        f"--- Step {step} ---\n",
        f"Goal: {observation.goal}\n\n",
    ]
    if observation.schema_dump:
        sections.append(f"Current DB Schema:\n{observation.schema_dump}\n\n")

    sections.append(f"Last Result (or Error):\n{observation.result}\n\n")

    if history:
        sections.append("Action History (Last 3 steps):\n")
        sections.extend(entry + "\n" for entry in history[-3:])

    sections.append("\nProvide the JSON with your next `action_str`:")
    return "".join(sections)

def parse_model_action(response_text: str, default: str = "SELECT 1;") -> str:
    """Extract the SQL statement from the model's reply.

    An optional Markdown code fence (```json ... ``` or ``` ... ```) is
    stripped, then the remainder is parsed as JSON.

    Args:
        response_text: raw assistant message content.
        default: statement returned when the reply is valid JSON but carries
            no usable SQL string (missing ``action_str`` key, a null or
            non-string value, or a JSON value that is neither an object nor
            a string).

    Returns:
        Always a ``str``. If the reply is not valid JSON at all, the
        fence-stripped text is returned verbatim, treating it as raw SQL.
    """
    text = response_text.strip()
    if text.startswith("```json"):
        text = text[7:]
    if text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return text

    # Callers treat the result as a string (e.g. call .replace on it), so
    # guard against JSON that is not an object (a list would fail on .get)
    # and against non-string values such as {"action_str": null}.
    action = data.get("action_str", default) if isinstance(data, dict) else data
    return action if isinstance(action, str) else default

def _request_action(client, step: int, observation, history: list) -> str:
    """Ask the model for the SQL statement to execute at ``step``.

    Builds the system + user messages, calls the chat-completions endpoint in
    JSON mode and parses the reply with ``parse_model_action``. A failed
    request or parse is reported on stderr and replaced by ``"SELECT 1;"``,
    so one bad model call does not abort the episode.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_user_prompt(step, observation, history)},
    ]
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
            response_format={"type": "json_object"},
        )
        response_text = completion.choices[0].message.content or ""
        action_str = parse_model_action(response_text)
    except Exception as exc:  # deliberate best-effort boundary around the network call
        print(f"[WARN] step {step}: model request failed: {exc}", file=sys.stderr)
        return "SELECT 1;"

    # Defensive: the caller calls str methods on the action, so never let a
    # non-string through (e.g. a JSON reply with "action_str": null).
    if not isinstance(action_str, str):
        return "SELECT 1;"
    return action_str


def run_task(task_id: int) -> float:
    """Run one episode of ``task_id`` against the shared environment.

    Resets ``env``, then for up to ``MAX_STEPS`` steps asks the model for a
    SQL statement and executes it with ``env.step``. Progress is printed to
    stdout as one ``[START]`` line, one ``[STEP]`` line per step and one
    ``[END]`` line. Diagnostics go to stderr so they do not disturb that
    format.

    Returns:
        The final score clamped to ``[0.01, 0.99]``, or ``0.0`` if the
        environment could not be reset.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    history = []
    rewards = []

    try:
        result = env.reset(task_id=task_id)
        observation = result.observation
        final_score = result.info.get("initial_score", 0.0)
    except Exception as exc:
        # Report why the task was skipped instead of silently swallowing it.
        print(f"[WARN] env.reset failed for task {task_id}: {exc}", file=sys.stderr)
        print(f"[START] task={task_id} env=sql-data-engineer-env model={MODEL_NAME}")
        print("[END] success=false steps=0 score=0.00 rewards=")
        return 0.0

    print(f"[START] task={task_id} env=sql-data-engineer-env model={MODEL_NAME}")

    step_count = 0
    for step in range(1, MAX_STEPS + 1):
        step_count = step
        action_str = _request_action(client, step, observation, history)

        error_msg = None  # None is printed as the bare token `null`
        try:
            step_result = env.step(BrowserGymAction(action_str=action_str))
            observation = step_result.observation
            # NOTE(review): reward is assumed numeric but may be None on some
            # steps; coerce so the ":.2f" formatting below cannot raise
            # outside this try.
            reward = float(step_result.reward or 0.0)
            done = step_result.done
            # Keep the last known score if the env omits it, rather than
            # silently resetting it to 0.0.
            final_score = step_result.info.get("current_score", final_score)
            if observation.last_action_error:
                error_msg = observation.result.replace("\n", " ")
        except Exception as exc:
            reward = 0.0
            done = True
            error_msg = str(exc).replace("\n", " ")

        rewards.append(f"{reward:.2f}")
        safe_action = action_str.replace("\n", " ")
        done_str = "true" if done else "false"
        err_out = "null" if error_msg is None else f'"{error_msg}"'
        print(
            f'[STEP] step={step} action="{safe_action}" reward={reward:.2f} '
            f"done={done_str} error={err_out}"
        )

        history.append(f"Step {step}: {safe_action[:50]}... -> reward {reward:+.2f}")

        if done:
            break

    # The reported score is clamped to [0.01, 0.99]; reaching the upper bound
    # counts as success.
    final_score = max(0.01, min(0.99, final_score))
    success_str = "true" if final_score >= 0.99 else "false"
    rewards_str = ",".join(rewards)
    print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}")

    return final_score

def main():
    """Run tasks 1, 2 and 3 one after another.

    Each ``run_task`` call prints its own ``[START]``/``[STEP]``/``[END]``
    log lines; the returned scores are not used here.
    """
    for task_id in range(1, 4):
        run_task(task_id)


if __name__ == "__main__":
    main()