File size: 5,612 Bytes
d7cc083
fa68c00
d7cc083
 
 
 
 
fa68c00
b871adb
 
 
fa68c00
 
d7cc083
 
 
 
 
 
fa68c00
d7cc083
 
fa68c00
 
 
d7cc083
 
 
 
 
 
 
 
fa68c00
 
 
 
 
d7cc083
 
 
 
 
 
 
 
fa68c00
 
 
 
 
 
 
 
 
 
d7cc083
 
fa68c00
 
d7cc083
 
 
 
 
 
 
 
 
fa68c00
d7cc083
 
 
 
 
fa68c00
d7cc083
 
fa68c00
d7cc083
fa68c00
 
 
d7cc083
fa68c00
 
 
d7cc083
fa68c00
 
d7cc083
fa68c00
 
 
 
 
d7cc083
fa68c00
 
d7cc083
fa68c00
d7cc083
 
 
fa68c00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7cc083
 
a37a022
d7cc083
 
 
fa68c00
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
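"""
inference.py — OpenEnv submission file.

Queries an OpenAI-compatible LLM for one data-cleaning action per
DataCleaningEnvironment episode (falling back to a rule-based heuristic),
prints [START]/[STEP]/[END] log lines, and writes scores.json.
"""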
import os, json, sys
from openai import OpenAI
from data_cleaning_env import DataCleaningEnvironment, CleaningAction

# ── Config ────────────────────────────────────────────────────────────────────
# Endpoint and model default to Groq's OpenAI-compatible API; both can be
# overridden through environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME   = os.getenv("MODEL_NAME",   "openai/gpt-oss-120b")  # Groq model name
# The API key is read from HF_TOKEN and passed straight to the OpenAI client.
HF_TOKEN     = os.getenv("HF_TOKEN")
# Treat an empty / whitespace-only value the same as unset: an exported-but-blank
# variable would otherwise only surface later, as failing LLM requests.
if not (HF_TOKEN or "").strip():
    raise ValueError("HF_TOKEN environment variable is required")

client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

# Identifier of the environment, echoed in every [START] log line.
ENV_NAME = "data_cleaning"

# Task id -> human-readable task name (ids start at 1).
TASK_NAMES = dict(enumerate(("remove_nulls", "fix_dates", "remove_outliers"), start=1))

# Instruction sent as the system message; asks the model for a bare JSON action.
SYSTEM_PROMPT = "\n".join([
    "You are a data cleaning expert. "
    "Respond ONLY with a valid JSON object, no markdown, no explanation.",
    'Format: {"action_type": "<remove_nulls|fix_dates|remove_outliers>", "column": "<col_or_null>"}',
])


def parse_llm_response(text: str, task_id: int) -> CleaningAction:
    """Turn a raw LLM reply into a ``CleaningAction``.

    The reply is expected to be a JSON object with ``action_type`` and
    ``column`` keys. Markdown code fences and any prose around the outermost
    ``{...}`` are ignored. An unknown ``action_type`` becomes ``"remove_nulls"``,
    and a ``column`` of ``""``/``"null"``/``"none"`` (any case) becomes ``None``.

    If the reply cannot be used as JSON (or the action cannot be built from
    it), a keyword fallback is applied to the text:

    - ``"date"`` gives ``fix_dates`` on ``hire_date``.
    - ``"outlier"`` gives ``remove_outliers`` on ``all``.
    - anything else gives ``remove_nulls``.

    Args:
        text: Model output; ``None`` (empty completion) is treated as ``""``.
        task_id: Task id copied into the returned action.

    Returns:
        A ``CleaningAction`` for ``task_id``; this function never raises on
        malformed model output.
    """
    valid_actions = ("remove_nulls", "fix_dates", "remove_outliers")
    cleaned = (text or "").strip().replace("```json", "").replace("```", "").strip()

    # Models sometimes wrap the JSON in prose; isolate the outermost {...} span.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    candidate = cleaned[start:end + 1] if 0 <= start < end else cleaned

    try:
        data = json.loads(candidate)
        if not isinstance(data, dict):
            raise ValueError("expected a JSON object")
        action_type = data.get("action_type", "remove_nulls")
        if action_type not in valid_actions:
            action_type = "remove_nulls"
        column = data.get("column")
        # The prompt's "<col_or_null>" placeholder invites a literal "null" string.
        if isinstance(column, str) and column.strip().lower() in ("", "null", "none"):
            column = None
        return CleaningAction(task_id=task_id, action_type=action_type, column=column)
    except Exception:
        # Deliberately broad: JSON errors, non-object payloads and action
        # validation errors all fall back to keyword matching.
        lowered = cleaned.lower()
        if "date" in lowered:
            return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
        if "outlier" in lowered:
            return CleaningAction(task_id=task_id, action_type="remove_outliers", column="all")
        return CleaningAction(task_id=task_id, action_type="remove_nulls")


def heuristic_action(task_id: int, obs) -> CleaningAction:
    """Pick a cleaning action from the observation's issue counters alone.

    Checked in priority order:

    1. Any nulls: ``remove_nulls``.
    2. Any date-format errors: ``fix_dates`` on ``hire_date``.
    3. Otherwise: ``remove_outliers`` on ``all``.

    ``obs`` must expose ``null_count`` and ``date_format_errors``.
    """
    if obs.null_count > 0:
        kwargs = {"action_type": "remove_nulls"}
    elif obs.date_format_errors > 0:
        kwargs = {"action_type": "fix_dates", "column": "hire_date"}
    else:
        kwargs = {"action_type": "remove_outliers", "column": "all"}
    return CleaningAction(task_id=task_id, **kwargs)


def run_episode(task_id: int, seed: int):
    """Run one single-step episode and report what happened.

    The LLM is asked for one cleaning action based on the reset observation.
    If the request or parsing raises, the error text is recorded and
    ``heuristic_action`` supplies the action instead. Exactly one
    ``env.step`` is taken. The environment is closed (when it has a
    ``close`` method) even if any step raises.

    Args:
        task_id: Task id passed to ``DataCleaningEnvironment``.
        seed: Seed passed to ``DataCleaningEnvironment``.

    Returns:
        ``(reward, action_str, done, error_str)``:

        - ``reward``: the step reward as a float.
        - ``action_str``: formatted as ``"<action_type>('<column or null>')"``.
        - ``done``: the step's done flag as a bool.
        - ``error_str``: the LLM exception message with newlines flattened,
          or the literal ``"null"`` when the LLM call succeeded.
    """
    env = DataCleaningEnvironment(task_id=task_id, seed=seed)
    try:
        obs = env.reset()
        error_str = "null"
        action = None

        user_msg = (
            f"Task {task_id}: {obs.task_description}\n"
            f"Nulls: {obs.null_count}, Date errors: {obs.date_format_errors}, "
            f"Outliers: {obs.outlier_count}\n"
            f"Preview:\n{obs.dataset_preview}\n"
            f"Respond with JSON only."
        )

        # ── Primary: LLM via OpenAI client ───────────────────────────────────
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user",   "content": user_msg},
                ],
                max_tokens=100,
                temperature=0.1,
            )
            action = parse_llm_response(resp.choices[0].message.content, task_id)
        except Exception as e:
            # Single-line error so it fits in the one-line [STEP] log format.
            error_str = str(e).replace("\n", " ")

        # ── Fallback: heuristic if LLM failed ────────────────────────────────
        if action is None:
            action = heuristic_action(task_id, obs)

        col = action.column if action.column else "null"
        action_str = f"{action.action_type}('{col}')"

        _, reward, done, _ = env.step(action)
    finally:
        # Release the environment even when reset/LLM/step raised.
        if hasattr(env, "close"):
            env.close()

    return float(reward), action_str, bool(done), error_str


def main():
    """Evaluate every task over ``N_EPISODES`` seeds and write ``scores.json``.

    For each task id in ``TASK_NAMES`` (ascending), seeds ``0..N_EPISODES-1``
    are run and logged.

    - A ``[START]`` line is printed, then one ``[STEP]`` line per episode.
    - An ``[END]`` line is always printed, even if an episode raises; the
      exception still propagates afterwards.
    - A task's score is its mean reward, clamped to [0, 1] and rounded to
      2 places. ``N_EPISODES <= 0`` yields a score of 0.0 instead of
      dividing by zero.

    The overall score (mean of task scores, 4 places) is written to
    ``scores.json`` along with the per-task scores and echoed as
    ``[SUMMARY]``.
    """
    all_results = {}
    n_episodes = int(os.getenv("N_EPISODES", "10"))

    for task_id in sorted(TASK_NAMES):
        task_name = TASK_NAMES[task_id]
        print(f"[START] task={task_name} env={ENV_NAME} model={MODEL_NAME}", flush=True)

        episode_rewards = []
        success = False
        score = 0.0

        try:
            for seed in range(n_episodes):
                reward, action_str, done, error_str = run_episode(task_id, seed)
                episode_rewards.append(reward)
                print(
                    f"[STEP] step={seed + 1} action={action_str} "
                    f"reward={reward:.2f} done={str(done).lower()} error={error_str}",
                    flush=True,
                )

            # An empty run (N_EPISODES <= 0) keeps score at 0.0 instead of
            # raising ZeroDivisionError.
            if episode_rewards:
                score = sum(episode_rewards) / len(episode_rewards)
            score = round(min(max(score, 0.0), 1.0), 2)
            all_results[task_id] = score
            success = score > 0.0

        finally:
            rewards_str = ",".join(f"{r:.2f}" for r in episode_rewards)
            # ── [END] with score= field as required ──────────────────────────
            print(
                f"[END] success={str(success).lower()} "
                f"steps={len(episode_rewards)} "
                f"score={score:.2f} "
                f"rewards={rewards_str}",
                flush=True,
            )

    overall = round(sum(all_results.values()) / max(len(all_results), 1), 4)
    with open("scores.json", "w", encoding="utf-8") as f:
        json.dump({"tasks": all_results, "overall": overall}, f, indent=2)
    print(f"[SUMMARY] overall_score={overall} task_scores={all_results}", flush=True)


if __name__ == "__main__":
    main()