File size: 9,615 Bytes
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c49ee77
ac326a6
 
 
 
 
 
 
 
 
 
 
7c2c5f2
ac326a6
 
7c2c5f2
 
 
 
ac326a6
 
 
7c2c5f2
 
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20312e4
ac326a6
20312e4
ac326a6
 
 
 
 
 
 
 
 
7c2c5f2
 
 
 
 
 
 
 
 
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c2c5f2
 
 
 
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
20312e4
ac326a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20312e4
 
ac326a6
 
 
 
 
 
 
 
20312e4
ac326a6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""Submission inference runner for CleanOps OpenEnv."""

from __future__ import annotations

import json
import os
from pathlib import Path
import sys
import textwrap
from typing import Any

from openai import OpenAI

PROJECT_ROOT = Path(__file__).resolve().parent
# Put this script's own directory at the front of the import path (only once),
# so the local ``cleanops_env`` package imported below can be found even when
# it is not installed as a distribution.
if str(PROJECT_ROOT) not in sys.path:
    sys.path[:0] = [str(PROJECT_ROOT)]

from cleanops_env import CleanOpsEnvClient, DataCleaningAction, LocalCleanOpsEnv
from cleanops_env.models import DataCleaningObservation
from cleanops_env.tasks import list_task_ids

# ---------------------------------------------------------------------------
# Runtime configuration; every value can be overridden via environment variable.
# ---------------------------------------------------------------------------
# OpenAI-compatible endpoint and model used to pick actions.
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# API token for the endpoint; when unset, no LLM client is created and the
# heuristic fallback policy drives every step.
HF_TOKEN: str | None = os.environ.get("HF_TOKEN")
# Image name passed to CleanOpsEnvClient.from_docker_image; when unset a
# LocalCleanOpsEnv is used instead.
LOCAL_IMAGE_NAME: str | None = os.environ.get("LOCAL_IMAGE_NAME")
# A single task id, or "all" to run every id returned by list_task_ids().
TASK_NAME: str = os.environ.get("TASK_NAME", "all")
# Label printed as ``env=`` in the [START] record.
BENCHMARK: str = os.environ.get("BENCHMARK", "cleanops_env")
# Upper bound on env.step calls per episode.
MAX_STEPS: int = int(os.environ.get("MAX_STEPS", "18"))
# Minimum final quality_score for an episode to count as a success.
SUCCESS_SCORE_THRESHOLD: float = float(os.environ.get("SUCCESS_SCORE_THRESHOLD", "0.95"))

# System message sent on every model call. It describes the JSON object that
# choose_action validates as a DataCleaningAction.
SYSTEM_PROMPT: str = textwrap.dedent(
    """
    You are a data-cleaning operations agent working in the CleanOps OpenEnv benchmark.
    Choose exactly one JSON action per turn using this schema:
    {
      "action_type": "inspect_table" | "inspect_operation" | "apply_operation" | "request_review" | "run_sync_dry_run" | "submit",
      "table_name": string | null,
      "operation_id": string | null,
      "entity_type": string | null,
      "entity_id": string | null,
      "target_system": "crm" | "billing" | null,
      "reason_code": string | null,
      "reasoning": string
    }
    Prefer safe/review operations that directly resolve current validation issues.
    Use request_review when the environment flags an ambiguous merge or repair decision.
    Use run_sync_dry_run before submit on medium and hard tasks when downstream risk still looks material.
    Avoid destructive operations unless the task objective explicitly asks for deletions.
    Submit once quality_score is high and remaining validation issues are gone.
    Return only a single JSON object.
    """
).strip()


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode's stdout log."""
    fields = (("task", task), ("env", env), ("model", model))
    print("[START] " + " ".join(f"{name}={value}" for name, value in fields), flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
    """Print one ``[STEP]`` record.

    Newlines and carriage returns in *action* and *error* are turned into spaces
    (and the ends stripped) so each record stays on one line. A falsy *error*
    is printed as ``null``; *reward* is printed with two decimals.
    """

    def flatten(text: str) -> str:
        return text.replace("\n", " ").replace("\r", " ").strip()

    action_field = flatten(action)
    error_field = flatten(error) if error else "null"
    done_field = str(done).lower()
    print(
        f"[STEP] step={step} action={action_field} reward={reward:.2f} "
        f"done={done_field} error={error_field}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the closing ``[END]`` record with per-step rewards comma-joined (2 decimals)."""
    rewards_field = ",".join(format(value, ".2f") for value in rewards)
    success_field = str(success).lower()
    print(f"[END] success={success_field} steps={steps} score={score:.2f} rewards={rewards_field}", flush=True)


def build_observation_prompt(observation: DataCleaningObservation) -> str:
    """Serialise *observation* into the compact JSON user message for the model.

    Nested model objects are converted with their ``model_dump()`` method
    (presumably pydantic models). Optional nested objects that are falsy
    (e.g. ``None``) become JSON ``null``. Keys are emitted in a fixed order, and
    the JSON uses no whitespace between separators.
    """

    def dump_each(items):
        return [item.model_dump() for item in items]

    def dump_if_present(item):
        return item.model_dump() if item else None

    obs = observation
    payload = {
        "task_id": obs.task_id,
        "difficulty": obs.difficulty,
        "objective": obs.objective,
        "quality_score": obs.quality_score,
        "remaining_steps": obs.remaining_steps,
        "review_budget_remaining": obs.review_budget_remaining,
        "supported_sync_targets": obs.supported_sync_targets,
        "downstream_health": obs.downstream_health.model_dump(),
        "risk_cards": dump_each(obs.risk_cards),
        "available_review_targets": dump_each(obs.available_review_targets),
        "pending_reviews": dump_each(obs.pending_reviews),
        "resolved_reviews": dump_each(obs.resolved_reviews),
        "last_dry_run": dump_if_present(obs.last_dry_run),
        "action_costs": dump_each(obs.action_costs),
        "table_summaries": dump_each(obs.table_summaries),
        "focus_table": dump_if_present(obs.focus_table),
        "focus_operation": dump_if_present(obs.focus_operation),
        "available_operations": dump_each(obs.available_operations),
        "validation_issues": dump_each(obs.validation_issues),
        "issue_cards": dump_each(obs.issue_cards),
        "recent_history": obs.recent_history,
        "last_action_status": obs.last_action_status,
        "last_action_error": obs.last_action_error,
        "grader": obs.grader.model_dump(),
    }
    return json.dumps(payload, separators=(",", ":"))


def fallback_action(observation: DataCleaningObservation) -> DataCleaningAction:
    """Choose an action without consulting the LLM.

    Preference order:
      1. the first operation recommended by an issue card (cards, then their
         recommended ids, scanned in order) that is present in
         ``available_operations``, not yet applied and not ``"destructive"``;
      2. otherwise the first available operation that is not yet applied and
         not ``"destructive"``;
      3. otherwise ``submit``.
    """

    def is_candidate(op) -> bool:
        # Only operations that still need applying and are not flagged destructive.
        return not op.already_applied and op.risk != "destructive"

    operations = observation.available_operations
    recommended_ids = (
        op_id
        for card in observation.issue_cards
        for op_id in card.recommended_operation_ids
    )
    for op_id in recommended_ids:
        # First operation with this id wins, mirroring a linear scan.
        match = next((op for op in operations if op.operation_id == op_id), None)
        if match and is_candidate(match):
            return DataCleaningAction(
                action_type="apply_operation",
                operation_id=match.operation_id,
                reasoning=f"Apply recommended operation {match.operation_id}.",
            )

    next_safe = next((op for op in operations if is_candidate(op)), None)
    if next_safe is None:
        return DataCleaningAction(
            action_type="submit",
            reasoning="Submit after exhausting all safe non-destructive operations.",
        )
    return DataCleaningAction(
        action_type="apply_operation",
        operation_id=next_safe.operation_id,
        reasoning=f"Apply next safe operation {next_safe.operation_id}.",
    )


def _extract_json_object(content: str) -> Any:
    """Decode the first JSON object embedded in a model reply.

    Chat models often wrap the requested JSON in Markdown fences
    (```` ```json ... ``` ````) or add a short preamble or epilogue. Plain
    ``json.loads`` rejects such replies, which would force the heuristic
    fallback. Decoding starts at the first ``{`` and stops at the end of that
    object, so surrounding text is ignored.

    Raises:
        ValueError: if *content* contains no ``{`` or the object is malformed
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start = content.find("{")
    if start == -1:
        raise ValueError("model reply contains no JSON object")
    payload, _consumed = json.JSONDecoder().raw_decode(content[start:])
    return payload


def choose_action(client: OpenAI | None, observation: DataCleaningObservation) -> DataCleaningAction:
    """Pick the next action for *observation*.

    - On the final step (``remaining_steps <= 1``) with no validation issues
      left, submit immediately.
    - Without a *client*, use ``fallback_action``.
    - Otherwise ask the model for one JSON action. If the request, JSON
      extraction or ``DataCleaningAction`` validation fails for any reason,
      report it on stderr and use ``fallback_action``.
    """
    if observation.remaining_steps <= 1 and not observation.validation_issues:
        return DataCleaningAction(action_type="submit", reasoning="Submit on final clean step.")
    if client is None:
        return fallback_action(observation)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": build_observation_prompt(observation)}],
            temperature=0.0,
            max_tokens=256,
            stream=False,
        )
        content = completion.choices[0].message.content or ""
        return DataCleaningAction.model_validate(_extract_json_object(content))
    except Exception as exc:  # network, decoding and validation failures alike
        # Report on stderr so the diagnostic does not interleave with the
        # [START]/[STEP]/[END] records printed to stdout.
        print(f"[WARN] model action unusable ({type(exc).__name__}: {exc}); using fallback", file=sys.stderr, flush=True)
        return fallback_action(observation)


def action_to_string(action: DataCleaningAction) -> str:
    """Render *action* as the compact call-style label used in ``[STEP]`` records.

    Any ``action_type`` not listed below, including ``"submit"``, renders as
    ``submit()``.
    """
    renderers = (
        ("inspect_table", lambda a: f"inspect_table({a.table_name})"),
        ("inspect_operation", lambda a: f"inspect_operation({a.operation_id})"),
        ("apply_operation", lambda a: f"apply_operation({a.operation_id})"),
        ("request_review", lambda a: f"request_review({a.entity_type},{a.entity_id},{a.reason_code})"),
        ("run_sync_dry_run", lambda a: f"run_sync_dry_run({a.target_system})"),
    )
    for kind, render in renderers:
        if action.action_type == kind:
            return render(action)
    return "submit()"


def create_env() -> Any:
    """Return the environment to run against.

    With LOCAL_IMAGE_NAME unset (or empty), this is a ``LocalCleanOpsEnv``.
    Otherwise it is the client returned by
    ``CleanOpsEnvClient.from_docker_image`` for that image.
    """
    if not LOCAL_IMAGE_NAME:
        return LocalCleanOpsEnv()
    return CleanOpsEnvClient.from_docker_image(LOCAL_IMAGE_NAME)


def _unpack_step_result(step_result: Any) -> tuple[Any, float, bool, str | None]:
    """Normalise an ``env.step`` result to ``(observation, reward, done, error)``.

    Two shapes are accepted:
      * a 4-tuple ``(observation, reward, done, info)``, where the error is read
        from ``info["last_action_error"]`` (a missing/None ``info`` gives no error);
      * an object exposing ``.observation``, ``.reward`` and ``.done``, where the
        error is read from ``observation.last_action_error``.
    In both shapes a missing (``None``) reward is reported as ``0.0`` and
    ``done`` is coerced to ``bool``.
    """
    if isinstance(step_result, tuple):
        observation, reward, done, info = step_result
        error = (info or {}).get("last_action_error")
    else:
        observation = step_result.observation
        reward, done = step_result.reward, step_result.done
        error = observation.last_action_error
    return observation, float(reward or 0.0), bool(done), error


def run_episode(task_name: str) -> None:
    """Run one episode of *task_name* and print its [START]/[STEP]/[END] records.

    The episode stops when the environment reports ``done`` or after MAX_STEPS
    steps. Success means the final ``quality_score`` reaches
    SUCCESS_SCORE_THRESHOLD *and* the episode is done. Any exception is
    reported as a synthetic failing ``[STEP]`` record. ``[END]`` is always
    printed and the environment is always closed (best effort).
    """
    env = None
    rewards: list[float] = []
    steps_taken = 0
    success = False
    final_score = 0.0
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
    try:
        env = create_env()
        result = env.reset(task_id=task_name, seed=7)
        # reset() may return either a result wrapper or the observation itself.
        observation = result.observation if hasattr(result, "observation") else result
        # Without a token the LLM is skipped and choose_action uses the fallback policy.
        client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN, timeout=30.0) if HF_TOKEN else None
        for step in range(1, MAX_STEPS + 1):
            if observation.done:
                break
            action = choose_action(client, observation)
            observation, reward, done, error = _unpack_step_result(env.step(action))
            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_to_string(action), reward=reward, done=done, error=error)
            if done:
                break
        final_score = float(observation.quality_score)
        success = bool(final_score >= SUCCESS_SCORE_THRESHOLD and observation.done)
    except Exception as exc:
        # Include the exception type: log_step prints a falsy error as "null",
        # which would hide exceptions raised with an empty message.
        log_step(step=steps_taken + 1, action="submit()", reward=0.0, done=True, error=f"{type(exc).__name__}: {exc}")
    finally:
        if env is not None:
            try:
                env.close()
            except Exception:
                pass  # best effort: a failing close must not suppress the [END] record
        log_end(success=success, steps=steps_taken, score=final_score, rewards=rewards)


def main() -> None:
    """Run one episode per selected task.

    TASK_NAME == "all" selects every id from ``list_task_ids()``; any other
    value is treated as a single task id.
    """
    selected = [TASK_NAME] if TASK_NAME != "all" else list_task_ids()
    for task_id in selected:
        run_episode(task_id)


if __name__ == "__main__":
    main()