File size: 5,849 Bytes
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c22bf49
dce68a7
 
 
 
 
 
 
 
 
 
c22bf49
dce68a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
from pathlib import Path

from fastapi.testclient import TestClient

from app import app
from env.environment import DataCleaningEnv
from env.graders import DataCleaningGrader
from env.models import Action

# Absolute directory containing this script; run_sequence reads each task's
# config from ROOT / "data" / "<task_name>.json".
ROOT = Path(__file__).resolve().parent


def assert_invalid_action_consumes_step() -> None:
    """Check that a rejected action still costs the agent one step.

    On ``basic_cleaning``, converting ``age`` to ``int`` straight after reset
    must be refused with reward 0.01 and ``info["error"] == "invalid_action"``,
    and the environment's ``steps_remaining`` must drop by exactly one
    relative to the reset observation.
    """
    environment = DataCleaningEnv("basic_cleaning")
    reset_observation = environment.reset()

    rejected = Action(
        action_type="convert_dtype",
        column="age",
        params={"target_dtype": "int"},
    )
    _observation, step_reward, _done, step_info = environment.step(rejected)

    assert step_reward == 0.01
    assert step_info["error"] == "invalid_action"
    # Read the reset observation's counter here (after the step), as before.
    expected_remaining = reset_observation.steps_remaining - 1
    assert environment.steps_remaining == expected_remaining


def assert_dependency_gate() -> None:
    """Check that ``moderate_cleaning`` refuses a premature dtype conversion.

    Converting ``salary`` to ``int`` as the very first action after reset must
    yield reward 0.01 and ``info["error"] == "invalid_action"``. (Presumably a
    prerequisite such as filling ``salary``'s missing values — which ``main``
    does first — must come before; TODO confirm against the env's rules.)
    """
    environment = DataCleaningEnv("moderate_cleaning")
    environment.reset()

    premature = Action(
        action_type="convert_dtype",
        column="salary",
        params={"target_dtype": "int"},
    )
    _observation, step_reward, _done, step_info = environment.step(premature)

    assert step_reward == 0.01
    assert step_info["error"] == "invalid_action"


def assert_api_contract() -> None:
    """Hit every public HTTP endpoint of ``app`` once and check its contract.

    Covers ``/``, ``/health``, ``/metadata``, ``/schema``, ``/reset``,
    ``/step``, ``/state`` and ``/mcp``. Every response must be HTTP 200 and
    carry the keys asserted below. ``/step`` and ``/state`` run after
    ``/reset`` on the same app, so they observe the episode it started.

    The client is used as a context manager so the app's lifespan
    (startup/shutdown) handlers run as they would under a real server, and
    the underlying client is closed when the checks finish.
    """
    with TestClient(app) as client:
        root_response = client.get("/")
        assert root_response.status_code == 200
        assert root_response.json()["name"] == "data_cleaning_env"

        # Check the status code first so a failing health endpoint reports a
        # clear assertion instead of a KeyError on an error payload.
        health_response = client.get("/health")
        assert health_response.status_code == 200
        assert health_response.json()["status"] == "healthy"

        metadata_response = client.get("/metadata")
        assert metadata_response.status_code == 200
        metadata_payload = metadata_response.json()
        assert metadata_payload["name"] == "data_cleaning_env"
        assert "description" in metadata_payload

        schema_response = client.get("/schema")
        assert schema_response.status_code == 200
        schema_payload = schema_response.json()
        assert {"action", "observation", "state"} <= set(schema_payload.keys())

        reset_response = client.post("/reset", json={"task_name": "basic_cleaning"})
        assert reset_response.status_code == 200
        assert "pending_issues" in reset_response.json()

        step_response = client.post(
            "/step",
            json={"action_type": "fill_missing", "column": "age", "params": {"strategy": "mean"}},
        )
        assert step_response.status_code == 200
        assert {"observation", "reward", "done", "info"} <= set(step_response.json().keys())

        state_response = client.get("/state")
        assert state_response.status_code == 200
        assert "quality_score" in state_response.json()

        # The request carries only the JSON-RPC envelope (no "method"); only
        # the envelope of the reply is checked here.
        mcp_response = client.post("/mcp", json={"jsonrpc": "2.0", "id": "smoke"})
        assert mcp_response.status_code == 200
        assert mcp_response.json()["jsonrpc"] == "2.0"


def run_sequence(task_name: str, actions: list[Action], expected_issues: int) -> tuple[dict, float]:
    """Replay ``actions`` on a fresh ``task_name`` environment and grade it.

    Args:
        task_name: Task identifier passed to ``DataCleaningEnv``; also names
            the config file ``ROOT / "data" / "<task_name>.json"``.
        actions: Actions applied in order. Replay stops early as soon as a
            step reports ``done``.
        expected_issues: Number of pending issues the task must start with;
            also passed to the grader as ``total_issues``.

    Returns:
        ``(final_state, score)``: the last observation dumped via
        ``model_dump()``, and the value returned by
        ``DataCleaningGrader().grade`` for it.

    Raises:
        AssertionError: if the initial issue count differs from
            ``expected_issues``, any step's ``info`` contains ``"error"``, or
            the quality score ends below its value at reset.
        KeyError: if the task config has no ``"max_steps"`` entry.
    """
    env = DataCleaningEnv(task_name)
    obs = env.reset()
    assert len(obs.pending_issues) == expected_issues, (task_name, len(obs.pending_issues), expected_issues)
    initial_quality = obs.quality_score

    for action in actions:
        # Per-step rewards are not checked here; only errors and termination.
        obs, _reward, done, info = env.step(action)
        assert "error" not in info, (task_name, action, info)
        if done:
            break

    assert obs.quality_score >= initial_quality, (task_name, initial_quality, obs.quality_score)
    final_state = obs.model_dump()

    config_path = ROOT / "data" / f"{task_name}.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    score = DataCleaningGrader().grade(
        final_state,
        {
            "total_issues": expected_issues,
            "max_steps": config["max_steps"],
        },
    )
    return final_state, score


def main() -> None:
    """Run the standalone smoke checks, then replay a scripted run per task.

    The three contract checks run first. After that, each task's action
    script is replayed via ``run_sequence`` together with its expected
    initial issue count. A one-line summary is printed per task: pending and
    resolved issue counts, remaining steps, and the grader score.
    """
    for check in (
        assert_invalid_action_consumes_step,
        assert_dependency_gate,
        assert_api_contract,
    ):
        check()

    # (task name, action script, expected number of initial pending issues)
    scripted_runs = [
        (
            "basic_cleaning",
            [
                Action(action_type="fill_missing", column="age", params={"strategy": "mean"}),
                Action(action_type="fill_missing", column="salary", params={"strategy": "median"}),
            ],
            2,
        ),
        (
            "moderate_cleaning",
            [
                Action(action_type="fill_missing", column="age", params={"strategy": "mean"}),
                Action(action_type="fill_missing", column="years_exp", params={"strategy": "median"}),
                Action(action_type="fill_missing", column="salary", params={"strategy": "median"}),
                Action(action_type="convert_dtype", column="salary", params={"target_dtype": "int"}),
                Action(action_type="drop_duplicates", column="__all__", params={}),
            ],
            5,
        ),
        (
            "full_pipeline",
            [
                Action(action_type="fill_missing", column="age", params={"strategy": "mean"}),
                Action(action_type="fill_missing", column="years_exp", params={"strategy": "median"}),
                Action(action_type="fill_missing", column="rating", params={"strategy": "mean"}),
                Action(action_type="fill_missing", column="salary", params={"strategy": "median"}),
                Action(action_type="convert_dtype", column="salary", params={"target_dtype": "int"}),
                Action(action_type="convert_dtype", column="rating", params={"target_dtype": "float"}),
                Action(action_type="normalize_category", column="city", params={}),
                Action(action_type="normalize_category", column="department", params={}),
                Action(action_type="create_feature", column="age_group", params={"feature_name": "age_group"}),
                Action(action_type="drop_duplicates", column="__all__", params={}),
            ],
            10,
        ),
    ]

    for task, script, issue_count in scripted_runs:
        state, grade = run_sequence(task, script, issue_count)
        summary = (
            f"{task}: pending={len(state['pending_issues'])} "
            f"resolved={len(state['resolved_issues'])} "
            f"steps_remaining={state['steps_remaining']} grader_score={grade}"
        )
        print(summary)


# Run the smoke checks only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()