File size: 7,179 Bytes
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c2c5f2
ac326a6
 
7c2c5f2
 
 
 
ac326a6
 
 
 
7c2c5f2
 
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
7c2c5f2
 
 
 
 
 
 
 
 
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""OpenAI baseline agent for CleanOps."""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
import sys
from typing import Any

from openai import OpenAI

# PROJECT_ROOT is the directory one level above this file's own folder.
# It is put at the front of sys.path (once) before the `cleanops_env` imports
# that follow, presumably so they resolve when this file is run directly as a
# script rather than as an installed package -- confirm against the repo layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from cleanops_env.local_env import LocalCleanOpsEnv
from cleanops_env.models import DataCleaningAction, DataCleaningObservation
from cleanops_env.tasks import list_task_ids

# System message sent with every chat-completions request in choose_action().
# It spells out the JSON action schema the model must answer with and the
# policy rules for picking an action. The text is part of runtime behaviour:
# any edit changes what the model is asked to do.
SYSTEM_PROMPT = """You are a careful data-cleaning operations agent.
Your job is to improve the current task score by choosing one JSON action at a time.
Use only this JSON schema:
{
  "action_type": "inspect_table" | "inspect_operation" | "apply_operation" | "request_review" | "run_sync_dry_run" | "submit",
  "table_name": string | null,
  "operation_id": string | null,
  "entity_type": string | null,
  "entity_id": string | null,
  "target_system": "crm" | "billing" | null,
  "reason_code": string | null,
  "reasoning": string
}
Rules:
- Prefer safe/review operations that directly address unresolved validation issues.
- Use request_review when an ambiguous merge or foreign-key repair needs confirmation.
- Use run_sync_dry_run before submit when downstream health is still weak.
- Avoid destructive operations unless the objective explicitly asks for row deletion.
- Call submit only when the data looks clean or there is 1 step left.
- Return a single JSON object and no extra text."""


def compact_observation(observation: DataCleaningObservation) -> dict[str, Any]:
    """Flatten an observation into a plain dict for use as the model prompt.

    Scalar fields are copied as-is, nested models are converted through their
    ``model_dump()`` method, optional models become ``None`` when falsy, and
    ``recent_history`` is trimmed to its last five entries. Keys are inserted
    in a fixed order so the serialized payload is stable between calls.
    """

    def dump_all(models):
        # Serialize every model of a collection, keeping its order.
        return [model.model_dump() for model in models]

    def dump_optional(model):
        # Serialize a model that may be absent.
        return model.model_dump() if model else None

    compact = {
        name: getattr(observation, name)
        for name in (
            "task_id",
            "task_title",
            "difficulty",
            "objective",
            "dataset_context",
            "quality_score",
            "remaining_steps",
            "review_budget_remaining",
            "supported_sync_targets",
        )
    }
    compact["downstream_health"] = observation.downstream_health.model_dump()
    for name in ("risk_cards", "available_review_targets", "pending_reviews", "resolved_reviews"):
        compact[name] = dump_all(getattr(observation, name))
    compact["last_dry_run"] = dump_optional(observation.last_dry_run)
    compact["action_costs"] = dump_all(observation.action_costs)
    compact["last_action_status"] = observation.last_action_status
    # Only the five most recent history entries are forwarded.
    compact["recent_history"] = observation.recent_history[-5:]
    compact["table_summaries"] = dump_all(observation.table_summaries)
    compact["focus_table"] = dump_optional(observation.focus_table)
    compact["focus_operation"] = dump_optional(observation.focus_operation)
    for name in ("available_operations", "validation_issues", "issue_cards"):
        compact[name] = dump_all(getattr(observation, name))
    compact["grader"] = observation.grader.model_dump()
    return compact


def choose_action(client: OpenAI, model: str, seed: int, observation: DataCleaningObservation) -> DataCleaningAction:
    """Ask the chat model for the next action, falling back to a local choice.

    The compacted observation is sent as the user message (JSON-encoded) next
    to ``SYSTEM_PROMPT``, with temperature 0 and JSON-object response format.
    If the reply cannot be parsed or validated into a ``DataCleaningAction``,
    a fallback is used: submit when at most one step remains or no suitable
    operation exists, otherwise apply the first operation that is neither
    already applied nor marked destructive.
    """
    base_kwargs = {
        "model": model,
        "temperature": 0,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(compact_observation(observation))},
        ],
        "response_format": {"type": "json_object"},
    }
    try:
        response = client.chat.completions.create(**base_kwargs, seed=seed)
    except TypeError:
        # Retry without the seed; presumably this covers client versions whose
        # create() does not accept a `seed` keyword -- TODO confirm.
        response = client.chat.completions.create(**base_kwargs)

    raw_reply = response.choices[0].message.content
    if not raw_reply:
        # An empty reply is treated as an empty JSON object.
        raw_reply = "{}"

    try:
        return DataCleaningAction.model_validate(json.loads(raw_reply))
    except Exception:
        # Any parse/validation failure lands here; pick an action locally.
        fallback_id = None
        for candidate in observation.available_operations:
            if candidate.already_applied or candidate.risk == "destructive":
                continue
            fallback_id = candidate.operation_id
            break
        if observation.remaining_steps <= 1 or fallback_id is None:
            return DataCleaningAction(action_type="submit", reasoning="Fallback submit because model output could not be parsed.")
        return DataCleaningAction(action_type="apply_operation", operation_id=fallback_id, reasoning="Fallback safe operation after parse failure.")


def run_baseline(model: str, seed: int, *, max_steps: int = 32) -> dict[str, Any]:
    """Run the OpenAI agent once over every registered task and build a report.

    For each task id the local environment is reset with ``seed`` and the
    agent is stepped until the environment reports ``done`` or ``max_steps``
    actions have been taken. Each model request uses ``seed`` plus the current
    step index as its seed.

    Args:
        model: Chat-completions model name forwarded to every request.
        seed: Base seed for the environment reset and the model requests.
        max_steps: Upper bound on agent steps per task. Defaults to 32, the
            limit that was previously hard-coded.

    Returns:
        A dict with the keys ``agent``, ``model``, ``seed``, ``tasks`` (one
        entry per task with final score, grader dump, step count, rounded
        total reward and the step trajectory) and ``mean_score``.
        ``mean_score`` is 0.0 when there are no tasks.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is unset or empty.
    """
    if not os.environ.get("OPENAI_API_KEY"):
        raise RuntimeError("OPENAI_API_KEY is not set. Export it before running this baseline.")

    openai_client = OpenAI()
    env = LocalCleanOpsEnv()
    results = []
    for task_id in list_task_ids():
        observation = env.reset(task_id=task_id, seed=seed)
        done = observation.done
        total_reward = 0.0
        step_count = 0
        trajectory = []
        # The step cap guards against episodes the environment never ends.
        while not done and step_count < max_steps:
            action = choose_action(openai_client, model, seed + step_count, observation)
            observation, reward, done, info = env.step(action)
            total_reward += reward
            step_count += 1
            trajectory.append(
                {
                    "action": action.model_dump(),
                    "reward": reward,
                    "score": observation.quality_score,
                    "done": done,
                    "status": info["last_action_status"],
                }
            )
        results.append(
            {
                "task_id": task_id,
                "final_score": observation.quality_score,
                "grader": observation.grader.model_dump(),
                "steps": step_count,
                "total_reward": round(total_reward, 4),
                "trajectory": trajectory,
            }
        )

    # Avoid ZeroDivisionError when no tasks are registered.
    if results:
        mean_score = round(sum(item["final_score"] for item in results) / len(results), 4)
    else:
        mean_score = 0.0
    return {
        "agent": "openai_chat_completions",
        "model": model,
        "seed": seed,
        "tasks": results,
        "mean_score": mean_score,
    }


def main() -> None:
    """Command-line entry point: run the baseline and emit the JSON report.

    Options:
        --model: model name; defaults to the ``OPENAI_MODEL`` environment
            variable, or ``gpt-4.1-mini`` when it is unset.
        --seed: integer seed; defaults to the ``OPENAI_SEED`` environment
            variable, or 7 when it is unset.
        --output: optional file path; when given, the report is also written
            there as UTF-8 with a trailing newline.

    The report is always printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Run the OpenAI baseline agent over all CleanOps tasks.")
    parser.add_argument(
        "--model",
        default=os.environ.get("OPENAI_MODEL", "gpt-4.1-mini"),
        help="Chat model name (default: OPENAI_MODEL env var, else gpt-4.1-mini).",
    )
    # The default is left as a string so argparse applies type=int itself:
    # a malformed OPENAI_SEED then yields a normal usage error instead of a
    # traceback, and is ignored entirely when --seed is given explicitly.
    parser.add_argument(
        "--seed",
        type=int,
        default=os.environ.get("OPENAI_SEED", "7"),
        help="Integer seed (default: OPENAI_SEED env var, else 7).",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
        help="Optional path to also write the JSON report to.",
    )
    args = parser.parse_args()
    report = run_baseline(model=args.model, seed=args.seed)
    rendered = json.dumps(report, indent=2)
    print(rendered)
    if args.output:
        output_path = Path(args.output)
        # Create missing parent directories so a finished run is not lost to
        # a FileNotFoundError at write time.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(rendered + "\n", encoding="utf-8")


if __name__ == "__main__":
    main()