File size: 9,281 Bytes
8d27c3e
 
9e89374
 
8d27c3e
 
 
b15226e
 
 
 
 
 
 
 
 
 
 
8d27c3e
c78c2fe
8d27c3e
 
9e89374
8d27c3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c78c2fe
 
 
 
8d27c3e
 
 
 
 
 
 
 
 
 
 
 
 
 
29473f6
8d27c3e
9e89374
 
 
 
 
29473f6
 
 
 
 
 
 
 
 
 
 
 
 
9e89374
 
29473f6
 
9e89374
 
29473f6
 
 
 
 
 
 
 
 
8d27c3e
b15226e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d27c3e
 
 
b15226e
8d27c3e
b15226e
 
 
 
 
8d27c3e
b15226e
 
 
 
8d27c3e
 
b15226e
 
 
 
8d27c3e
 
b15226e
 
8d27c3e
 
b15226e
 
8d27c3e
b15226e
8d27c3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b15226e
8d27c3e
b15226e
8d27c3e
 
 
 
b15226e
8d27c3e
 
 
 
 
 
 
 
 
 
 
b15226e
 
8d27c3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b15226e
 
 
8d27c3e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import os
import asyncio
import json
import re
from openai import AsyncOpenAI

try:
    from models import DataWranglerAction
except (ImportError, ModuleNotFoundError):
    import sys
    sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
    from models import DataWranglerAction

# --- Settings that can be overridden through environment variables ---
# Base URL and model name used for the chat-completion requests.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
# Passed to AsyncOpenAI as the API key; None when the variable is unset,
# in which case main() ends the run immediately.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Image name handed to DataWranglerEnv.from_docker_image.
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME", "data_wrangler")
# Number of most recent history entries included in each prompt.
MAX_HISTORY_ITEMS = int(os.environ.get("MAX_HISTORY_ITEMS", "6"))

# --- Fixed episode settings ---
# Labels printed on the [START] log line.
TASK_NAME = "data_wrangler_task"
BENCHMARK = "data_wrangler"
# Upper bound on environment steps per episode.
MAX_STEPS = 15
# Divisor applied to the summed rewards to obtain the final score.
MAX_TOTAL_REWARD = 1.0
# Minimum score for the episode to be reported as a success.
SUCCESS_SCORE_THRESHOLD = 0.5

# System prompt sent as the "system" message on every chat-completion call
# made by get_model_message. It describes the observation fields, lists the
# eight JSON action schemas the model may emit, and asks for a <thinking>
# block followed by a single JSON object (which get_model_message extracts
# from the reply). The backslash right after the opening quotes suppresses
# the leading newline of the literal.
system_prompt = """\
SYSTEM INSTRUCTIONS: ELITE DATA ENGINEER AGENT

ROLE AND PERSONA
You are an elite Data Engineering AI Agent operating within an automated data-wrangling pipeline. Your core function is to autonomously clean, format, and standardize messy, real-world datasets until they perfectly match a hidden "ground truth" target. You operate systematically, analytically, and with absolute precision.

MISSION OBJECTIVE
At each step, you will receive an Observation of the current data state. You must analyze the data anomalies (missing values, bad schemas, incorrect data types) and issue exactly ONE valid operation from your Action Space. You will iterate on this process until the dataset is perfectly clean, at which point you will issue the submit action.

THE OBSERVATION
You will receive a state dictionary detailing the dataset's current form:
columns: Current list of headers.
row_count: Total number of rows in the dataset.
column_stats: Dictionary mapping column names to {dtype, missing_count, sample_values}.
last_action_feedback: Status/error message resulting from your previous action.
is_done: Boolean termination flag.

ACTION SPACE (AVAILABLE TOOLS)
You have a strict, highly constrained toolset. Your chosen action MUST be a valid JSON object matching exactly ONE of the schemas:
1. Drop Column: {"action_type": "drop_column", "target_column": "..."}
2. Rename Column: {"action_type": "rename_column", "target_column": "...", "new_name": "..."}
3. Fill Missing Values: {"action_type": "fill_missing", "target_column": "...", "fill_value": "..."}
4. Cast Data Type: {"action_type": "cast_type", "target_column": "...", "cast_to": "..."}
5. Extract Regex: {"action_type": "extract_regex", "target_column": "...", "new_name": "...", "regex_pattern": "..."}
6. Parse Datetime: {"action_type": "datetime_parse", "target_column": "...", "format_string": "..."}
7. Group By & Aggregate: {"action_type": "group_by_aggregate", "target_column": "...", "agg_column": "...", "agg_func": "sum|mean|count"}
8. Submit: {"action_type": "submit"}

REQUIRED OUTPUT FORMAT (CHAIN OF THOUGHT)
<thinking>
Analyze Observation: What is the current state? What did the last action do?
Identify Anomalies: Which columns have wrong types, bad names, or missing data?
Formulate Plan: What is the highest priority fix right now?
Select Action: Which action type and parameters will execute this fix?
</thinking>
{
"action_type": "...",
...
}
"""

def _extract_action_json(content):
    """Pull the action dict out of a raw model reply.

    Each ``{`` in *content* is tried as the start of a JSON value with
    ``json.JSONDecoder.raw_decode``. Unlike a brace-matching regex this copes
    with arbitrarily nested objects, with braces inside string values (for
    example a regex such as ``"\\d{4}"``) and with brace-delimited snippets in
    the free-text ``<thinking>`` section that are not valid JSON.

    Args:
        content: Reply text from the model; may be ``None`` or empty.

    Returns:
        The first decoded dict that has an ``"action_type"`` key; failing
        that, the first dict that decoded at all; ``None`` if no JSON object
        could be decoded.
    """
    text = content or ""
    decoder = json.JSONDecoder()
    first_dict = None
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            if "action_type" in candidate:
                return candidate
            if first_dict is None:
                first_dict = candidate
        start = text.find("{", start + 1)
    return first_dict


async def get_model_message(client, step, obs_dict, last_reward, history, max_retries=3):
    """Ask the model for the next action and return it as a dict.

    Builds a user prompt from the step number, the observation, the last
    reward and the most recent ``MAX_HISTORY_ITEMS`` history entries, sends it
    together with ``system_prompt`` and parses a JSON object out of the reply.
    When the reply contains no usable JSON, or the request raises, a warning
    is appended to the prompt and the call is retried.

    Args:
        client: Object exposing ``chat.completions.create`` as an awaitable
            (an ``AsyncOpenAI`` instance in ``main``).
        step: Current step number, used only in the prompt text.
        obs_dict: Observation mapping; ``last_action_feedback`` is inspected
            for error text.
        last_reward: Reward of the previous step, used only in the prompt.
        history: List of earlier step summaries; may be empty or ``None``.
        max_retries: Maximum number of model calls to attempt.

    Returns:
        The parsed action dict, or ``{"action_type": "submit"}`` when every
        attempt failed.
    """
    obs_text = str(obs_dict)
    trimmed_history = history[-MAX_HISTORY_ITEMS:] if history else []
    prompt = (
        f"Step {step}.\nObservation: {obs_text}\nLast Reward: {last_reward}\n"
        f"History: {trimmed_history}\nChoose your next action (JSON matching schema)."
    )

    # Error reflection: surface the previous failure to the model explicitly.
    # The feedback is coerced to a string so that a None (or otherwise
    # non-string) value cannot make the substring checks raise.
    feedback = str(obs_dict.get("last_action_feedback") or "")
    if "Error" in feedback or "Exception" in feedback:
        prompt += f"\nCRITICAL: Your last action failed with this error: {feedback}. Review your <thinking> block to correct your mistake before trying a new action."

    for attempt in range(max_retries):
        try:
            # NOTE(review): max_tokens has to cover both the <thinking> block
            # and the JSON; a truncated reply yields no JSON and is retried.
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.0,
                max_tokens=220,
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Boundary with the remote API: any failure counts as one attempt.
            prompt += f"\nWarning: Exception on attempt {attempt+1}: {str(e)}. Provide valid JSON."
            continue

        action = _extract_action_json(content)
        if action is not None:
            return action
        prompt += f"\nWarning: Failed to extract JSON on attempt {attempt+1}. Provide ONLY valid JSON inside curly braces."

    # Fallback only if absolutely all retries fail
    return {"action_type": "submit"}

def _bool_str(value):
    return "true" if bool(value) else "false"


def _action_str(action):
    try:
        return json.dumps(action, separators=(",", ":"), ensure_ascii=False)
    except Exception:
        return str(action).replace("\n", " ")


def _reward_str(value):
    try:
        return f"{float(value):.2f}"
    except Exception:
        return "0.00"


def log_start(task, env, model):
    """Print the ``[START]`` line naming the task, environment and model of the run."""
    line = f"[START] task={task} env={env} model={model}"
    print(line)


def log_step(step, action, reward, done, error):
    """Print one ``[STEP]`` line describing the outcome of a single step.

    ``error`` is rendered as ``null`` when it is ``None``; otherwise its
    string form is used with newlines replaced by spaces so the record stays
    on one line.
    """
    if error is None:
        error_str = "null"
    else:
        error_str = str(error).replace("\n", " ")
    line = (
        f"[STEP] step={step} action={_action_str(action)} "
        f"reward={_reward_str(reward)} done={_bool_str(done)} error={error_str}"
    )
    print(line)


def log_end(success, steps, rewards):
    """Print the ``[END]`` line with the success flag, step count and comma-joined rewards."""
    formatted = [_reward_str(r) for r in rewards]
    rewards_csv = ",".join(formatted)
    print(f"[END] success={_bool_str(success)} steps={steps} rewards={rewards_csv}")

def _observation_to_dict(obs):
    """Copy the fields shown to the model from *obs* into a plain dict.

    Missing attributes fall back to empty/zero defaults.
    """
    return {
        "columns": getattr(obs, "columns", []),
        "row_count": getattr(obs, "row_count", 0),
        "column_stats": getattr(obs, "column_stats", {}),
        "last_action_feedback": getattr(obs, "last_action_feedback", ""),
        "is_done": getattr(obs, "is_done", False),
    }


async def _call_maybe_async(func, *args):
    """Call *func* with *args* and await the result if it is awaitable.

    Lets the same code drive both synchronous and asynchronous environment
    implementations without inspecting the callable up front.
    """
    import inspect  # local import keeps this block self-contained

    outcome = func(*args)
    if inspect.isawaitable(outcome):
        outcome = await outcome
    return outcome


async def main():
    """Run one episode against the data-wrangler environment and log it.

    Emits a ``[START]`` line, one ``[STEP]`` line per attempted action and a
    final ``[END]`` line. The environment is created from the Docker image
    named by ``LOCAL_IMAGE_NAME``; if that fails for any reason the in-process
    ``DataWranglerEnvironment`` is used instead. The episode stops when the
    environment reports completion or after ``MAX_STEPS`` steps. Success means
    the summed rewards, divided by ``MAX_TOTAL_REWARD`` and clamped to
    ``[0, 1]``, reach ``SUCCESS_SCORE_THRESHOLD``.

    A model reply that cannot be turned into a ``DataWranglerAction`` is
    recorded as a failed step with zero reward and reported back to the model
    on the next prompt instead of aborting the episode.
    """
    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    # Without a credential the client cannot be used; end the run right away.
    if not HF_TOKEN:
        log_end(success=False, steps=0, rewards=[])
        return

    client = AsyncOpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    try:
        from client import DataWranglerEnv
        # from_docker_image may be sync or async depending on the client
        # implementation; the helper handles either.
        env = await _call_maybe_async(DataWranglerEnv.from_docker_image, LOCAL_IMAGE_NAME)
    except Exception:
        from server.data_wrangler_environment import DataWranglerEnvironment
        env = DataWranglerEnvironment()

    history = []
    rewards = []
    steps_taken = 0
    score = 0.0
    success = False

    try:
        result = await _call_maybe_async(env.reset)

        # Some environments wrap the observation in a result object, others
        # return the observation itself.
        obs = getattr(result, "observation", result)
        obs_dict = _observation_to_dict(obs)
        last_reward = getattr(result, "reward", getattr(obs, "reward", 0.0)) or 0.0

        for step in range(1, MAX_STEPS + 1):
            if getattr(result, "done", getattr(obs, "is_done", False)):
                break

            action_data = await get_model_message(client, step, obs_dict, last_reward, history)

            try:
                action_obj = DataWranglerAction(**action_data)
            except (TypeError, ValueError) as exc:
                # The model produced fields the action type rejects. Record a
                # failed step and let the next prompt carry the error text.
                error = f"Error: invalid action: {exc}"
                reward = 0.0
                rewards.append(reward)
                steps_taken = step
                last_reward = reward
                obs_dict["last_action_feedback"] = error
                log_step(step=step, action=action_data, reward=reward, done=False, error=error)
                history.append(f"Step {step}: {action_data} -> reward {reward:+.2f}")
                continue

            result = await _call_maybe_async(env.step, action_obj)

            obs = getattr(result, "observation", result)
            obs_dict = _observation_to_dict(obs)

            reward = getattr(result, "reward", getattr(obs, "reward", 0.0)) or 0.0
            done = getattr(result, "done", getattr(obs, "is_done", False))
            # Coerce to str so a None feedback cannot break the checks below.
            feedback = str(obs_dict.get("last_action_feedback") or "")
            error = feedback if ("Error" in feedback or "Exception" in feedback) else None

            rewards.append(reward)
            steps_taken = step
            last_reward = reward

            log_step(step=step, action=action_data, reward=reward, done=done, error=error)

            history.append(f"Step {step}: {action_data} -> reward {reward:+.2f}")

            if done:
                break

        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        close = getattr(env, "close", None)
        if close is not None:
            try:
                await _call_maybe_async(close)
            except Exception:
                # Cleanup is best-effort: a failing close() must not stop the
                # [END] line from being printed.
                pass

        log_end(success=success, steps=steps_taken, rewards=rewards)

if __name__ == "__main__":
    asyncio.run(main())