File size: 6,365 Bytes
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c22bf49
 
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c22bf49
dce68a7
 
c22bf49
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c22bf49
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c22bf49
 
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c22bf49
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
STDOUT FORMAT (must match exactly):
[START] task=<task_name> env=data_cleaning_env model=<model_name>
[STEP]  step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END]   success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
"""

import json
import os

from openai import OpenAI

from env.environment import DataCleaningEnv
from env.graders import DataCleaningGrader
from env.models import Action

# Connection settings are read from the environment once, at import time.
# HF_TOKEN has no default and is None when unset; run_task() passes it to the
# OpenAI client as the API key after checking it with require_env().
HF_TOKEN = os.getenv("HF_TOKEN")
# Base URL of the OpenAI-compatible endpoint the client talks to.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier sent with every chat completion request.
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
# Value printed in the `env=` field of the [START] log line.
BENCHMARK = "data_cleaning_env"

# Task names; main() runs them in this order.
TASKS = ["basic_cleaning", "moderate_cleaning", "full_pipeline"]

# System message placed first in every task conversation (see run_task()).
# It tells the model what an observation contains and constrains the reply to
# a single JSON object, which parse_action() turns into an Action.
# NOTE: this text is sent to the model verbatim; editing it changes behavior.
SYSTEM_PROMPT = """You are an AI agent performing data cleaning on a tabular dataset.

You will receive an observation containing:
- data_preview: first 5 rows of the current dataset
- columns: column info (name, dtype, null_count, unique_count)
- pending_issues: list of issues to fix (each has issue_id, issue_type, column, description, depends_on)
- resolved_issues: issues already fixed
- action_history: your previous actions
- quality_score: current data quality (0.0-1.0)
- steps_remaining: how many actions you have left

You must respond with EXACTLY one JSON object representing your action:
{
    "action_type": "<one of: fill_missing, drop_duplicates, convert_dtype, normalize_category, create_feature>",
    "column": "<target column name or __all__ for drop_duplicates>",
    "params": {<strategy-specific params>}
}

Rules:
- fill_missing: params must have "strategy" key. Use "mean"/"median"/"zero" for numeric columns, "mode"/"unknown" for categorical.
- drop_duplicates: column = "__all__", params = {}
- convert_dtype: params must have "target_dtype" key (one of: int, float, str, bool)
- normalize_category: params = {}
- create_feature: params must have "feature_name" key (e.g., "age_group")

IMPORTANT: Fix dependencies first! Check the "depends_on" field of each issue. For example, fill missing string values in a column BEFORE converting its dtype.

Respond with ONLY the JSON object. No explanation, no markdown, no code blocks."""


def _first_json_object(text: str) -> dict | None:
    """Return the first JSON object embedded in *text*, or None if none decodes."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode tolerates trailing text after the object, which is
            # exactly what a chatty reply ("{...} Hope that helps") contains.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return candidate
    return None


def parse_action(response_text: str) -> Action:
    """Convert a raw model reply into an ``Action``.

    The reply is expected to be a single JSON object. A surrounding Markdown
    code fence and a leading ``json`` language tag are stripped first. If the
    remaining text is still not valid JSON (for example because the model
    added an explanation before or after the object), the first decodable
    JSON object found in the text is used instead.

    Args:
        response_text: The text content returned by the model.

    Returns:
        An ``Action`` built from the decoded object's keys.

    Raises:
        json.JSONDecodeError: No JSON value could be decoded from the reply.
        TypeError: The decoded JSON value is not an object (e.g. a list).
    """
    text = response_text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (e.g. "```json"); a single-line reply
        # only loses the three backticks.
        parts = text.split("\n", 1)
        text = parts[1] if len(parts) > 1 else text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
    if text.startswith("json"):
        text = text[4:].strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = _first_json_object(text)
        if parsed is None:
            # Nothing salvageable: surface the original decode error.
            raise

    if not isinstance(parsed, dict):
        raise TypeError(
            f"Expected a JSON object for the action, got {type(parsed).__name__}"
        )
    return Action(**parsed)


def require_env(name: str, value: str | None) -> str:
    if value:
        return value
    raise RuntimeError(f"Missing required environment variable: {name}")


def safe_log_value(value: str | None) -> str:
    if not value:
        return "null"
    return str(value).replace("\n", "_").replace("\r", "_").replace("\t", "_").replace(" ", "_")


def log_start(task, env, model):
    """Print the ``[START]`` line that opens a task run and flush stdout."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)


def log_step(step, action_str, reward, done, error):
    """Print one ``[STEP]`` line describing a single environment step.

    ``action_str`` and ``error`` are passed through ``safe_log_value`` so they
    stay single tokens, ``reward`` is printed with two decimals, and ``done``
    is printed as lowercase ``true``/``false``.
    """
    fields = (
        f"step={step}",
        f"action={safe_log_value(action_str)}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={safe_log_value(error)}",
    )
    print("[STEP] " + " ".join(fields), flush=True)


def log_end(success, steps, rewards):
    """Print the closing ``[END]`` line for a task run and flush stdout.

    ``success`` is printed as lowercase ``true``/``false`` and ``rewards`` as
    a comma-separated list with two decimals each (empty when there are none).
    """
    formatted = [format(value, ".2f") for value in rewards]
    outcome = str(success).lower()
    print(
        f"[END] success={outcome} steps={steps} rewards={','.join(formatted)}",
        flush=True,
    )


def _observation_to_dict(obs) -> dict:
    """Dump an observation to a plain dict via ``model_dump`` if present, else ``dict``."""
    return obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()


def run_task(task_name: str):
    """Run one data-cleaning task with the configured model and return its score.

    The function resets a ``DataCleaningEnv`` for *task_name*, then repeatedly
    sends the current observation to the chat model, parses the reply into an
    action and applies it with ``env.step``. It stops when the environment
    reports ``done`` or after as many iterations as the initial observation's
    ``steps_remaining``. One ``[START]`` line, one ``[STEP]`` line per
    iteration and one ``[END]`` line are printed to stdout.

    Args:
        task_name: Name of the task passed to ``DataCleaningEnv``.

    Returns:
        The value returned by ``DataCleaningGrader().grade`` for the final
        observation.

    Raises:
        RuntimeError: A required environment setting is missing or empty.
    """
    client = OpenAI(
        base_url=require_env("API_BASE_URL", API_BASE_URL),
        api_key=require_env("HF_TOKEN", HF_TOKEN),
    )
    env = DataCleaningEnv(task_name=task_name)
    obs = env.reset()
    # Resolved once: the value cannot change while the task is running.
    model_name = require_env("MODEL_NAME", MODEL_NAME)
    log_start(task_name, BENCHMARK, model_name)

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    rewards_list = []
    step_count = 0
    done = False
    max_possible_steps = obs.steps_remaining

    while not done and step_count < max_possible_steps:
        # Every iteration consumes exactly one step, whether it succeeds or
        # fails, so the loop is guaranteed to terminate.
        step_count += 1
        obs_dict = _observation_to_dict(obs)
        messages.append(
            {
                "role": "user",
                "content": f"Current observation:\n{json.dumps(obs_dict, indent=2, default=str)}\n\nChoose your next action.",
            }
        )

        got_reply = False
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=200,
            )
            response_text = response.choices[0].message.content or ""
            messages.append({"role": "assistant", "content": response_text})
            got_reply = True

            action = parse_action(response_text)
            obs, reward, done, info = env.step(action)
            last_error = info.get("error")
            rewards_list.append(reward)

            action_str = f"{action.action_type}({action.column})"
            log_step(step_count, action_str, reward, done, last_error)

        except Exception as exc:
            # Per-step boundary: a failed model call, an unparsable reply or a
            # rejected action costs one step but must not abort the task.
            rewards_list.append(0.01)
            log_step(step_count, "parse_error", 0.01, False, str(exc))
            if got_reply:
                messages.append(
                    {
                        "role": "user",
                        "content": f"Your response could not be parsed. Error: {exc}. Respond with ONLY a valid JSON action object.",
                    }
                )
            else:
                # The model never answered, so there is nothing to correct.
                # Drop the unanswered observation; the next iteration sends a
                # fresh one instead of stacking duplicates in the history.
                messages.pop()

    success = hasattr(obs, "pending_issues") and len(obs.pending_issues) == 0
    try:
        final_state = _observation_to_dict(obs)
        task_score = DataCleaningGrader().grade(
            final_state,
            {
                "total_issues": final_state["total_issues_at_start"],
                "max_steps": max_possible_steps,
            },
        )
    finally:
        # The [END] line closes the run even when grading fails.
        log_end(success, step_count, rewards_list)
    return task_score


def main():
    """Validate the required settings, run every task and return the scores.

    Returns:
        A dict mapping each name in ``TASKS`` to the value ``run_task``
        returned for it.

    Raises:
        RuntimeError: ``HF_TOKEN``, ``API_BASE_URL`` or ``MODEL_NAME`` is
            missing or empty (checked in that order, before any task runs).
    """
    required_settings = (
        ("HF_TOKEN", HF_TOKEN),
        ("API_BASE_URL", API_BASE_URL),
        ("MODEL_NAME", MODEL_NAME),
    )
    for setting_name, setting_value in required_settings:
        require_env(setting_name, setting_value)
    return {task: run_task(task) for task in TASKS}


# Script entry point: run all tasks only when executed directly, not on import.
if __name__ == "__main__":
    main()