Spaces:
Sleeping
Sleeping
File size: 5,849 Bytes
dce68a7 c22bf49 dce68a7 c22bf49 dce68a7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | import json
from pathlib import Path
from fastapi.testclient import TestClient
from app import app
from env.environment import DataCleaningEnv
from env.graders import DataCleaningGrader
from env.models import Action
ROOT = Path(__file__).resolve().parent
def assert_invalid_action_consumes_step() -> None:
    """Check that a rejected action still costs the agent one step.

    On a fresh ``basic_cleaning`` episode, ``convert_dtype`` of ``age`` to
    ``int`` is expected to be rejected. The step must:

    - yield a reward of 0.01;
    - report ``info["error"] == "invalid_action"``;
    - decrement ``steps_remaining`` by exactly one.

    Raises ``AssertionError`` (with the offending values) on any violation.
    """
    import math

    env = DataCleaningEnv("basic_cleaning")
    obs = env.reset()
    # Snapshot the budget before stepping so the check does not depend on
    # whether env.step() mutates the observation object returned by reset().
    steps_before = obs.steps_remaining
    _, reward, _, info = env.step(
        Action(action_type="convert_dtype", column="age", params={"target_dtype": "int"})
    )
    # The reward is compared against a float literal, so avoid exact equality.
    assert math.isclose(reward, 0.01), reward
    assert info.get("error") == "invalid_action", info
    assert env.steps_remaining == steps_before - 1, (env.steps_remaining, steps_before)
def assert_dependency_gate() -> None:
    """Check that converting ``salary`` as the very first action is rejected.

    On a fresh ``moderate_cleaning`` episode, ``convert_dtype`` of ``salary``
    to ``int`` must yield a reward of 0.01 and report
    ``info["error"] == "invalid_action"``.

    NOTE(review): the name suggests the conversion is gated on an earlier
    cleaning step (presumably filling salary's missing values first).
    Confirm this against the environment's rules.

    Raises ``AssertionError`` (with the offending values) on any violation.
    """
    import math

    env = DataCleaningEnv("moderate_cleaning")
    env.reset()
    _, reward, _, info = env.step(
        Action(action_type="convert_dtype", column="salary", params={"target_dtype": "int"})
    )
    # The reward is compared against a float literal, so avoid exact equality.
    assert math.isclose(reward, 0.01), reward
    assert info.get("error") == "invalid_action", info
def assert_api_contract() -> None:
    """Smoke-test the HTTP endpoints exposed by ``app``.

    The check covers three areas:

    - the discovery endpoints ``/``, ``/health``, ``/metadata`` and ``/schema``;
    - one ``/reset`` -> ``/step`` -> ``/state`` round trip on ``basic_cleaning``;
    - a minimal JSON-RPC envelope posted to ``/mcp``.

    Every response must be HTTP 200 and carry the listed keys. Raises
    ``AssertionError`` on the first violation; status failures include the
    response body.
    """
    # Using the client as a context manager runs the app's startup/lifespan
    # handlers (if it defines any) and closes the client afterwards. A bare
    # ``TestClient(app)`` skips both.
    with TestClient(app) as client:
        root_response = client.get("/")
        assert root_response.status_code == 200, root_response.text
        assert root_response.json()["name"] == "data_cleaning_env"

        health_response = client.get("/health")
        assert health_response.status_code == 200, health_response.text
        assert health_response.json()["status"] == "healthy"

        metadata_response = client.get("/metadata")
        assert metadata_response.status_code == 200, metadata_response.text
        metadata_payload = metadata_response.json()
        assert metadata_payload["name"] == "data_cleaning_env"
        assert "description" in metadata_payload

        schema_response = client.get("/schema")
        assert schema_response.status_code == 200, schema_response.text
        schema_payload = schema_response.json()
        assert {"action", "observation", "state"} <= set(schema_payload.keys())

        # Issued in this order on the same client. /step and /state presumably
        # operate on the episode started by /reset.
        reset_response = client.post("/reset", json={"task_name": "basic_cleaning"})
        assert reset_response.status_code == 200, reset_response.text
        assert "pending_issues" in reset_response.json()

        step_response = client.post(
            "/step",
            json={"action_type": "fill_missing", "column": "age", "params": {"strategy": "mean"}},
        )
        assert step_response.status_code == 200, step_response.text
        assert {"observation", "reward", "done", "info"} <= set(step_response.json().keys())

        state_response = client.get("/state")
        assert state_response.status_code == 200, state_response.text
        assert "quality_score" in state_response.json()

        # The request carries no "method". The check only requires that the
        # endpoint answers with a JSON-RPC 2.0 envelope.
        mcp_response = client.post("/mcp", json={"jsonrpc": "2.0", "id": "smoke"})
        assert mcp_response.status_code == 200, mcp_response.text
        assert mcp_response.json()["jsonrpc"] == "2.0"
def run_sequence(task_name: str, actions: list[Action], expected_issues: int) -> tuple[dict, float]:
    """Replay ``actions`` on a fresh ``task_name`` episode and grade the result.

    The function asserts that:

    - the episode starts with exactly ``expected_issues`` pending issues;
    - no replayed action reports an ``"error"`` in its info;
    - the final quality score is not below the initial one.

    Replay stops early once the environment reports ``done``. The final
    observation is then graded with ``DataCleaningGrader``, using
    ``max_steps`` from ``data/<task_name>.json`` under ``ROOT``.

    Returns:
        ``(final_state, score)``: the final observation as a dict (via
        ``model_dump()``) and the grader's score.
    """
    env = DataCleaningEnv(task_name)
    observation = env.reset()
    issue_count = len(observation.pending_issues)
    assert issue_count == expected_issues, (task_name, issue_count, expected_issues)
    baseline_quality = observation.quality_score

    for step_action in actions:
        observation, _reward, finished, step_info = env.step(step_action)
        assert "error" not in step_info, (task_name, step_action, step_info)
        if finished:
            break

    # Cleaning must never make the dataset worse overall.
    assert observation.quality_score >= baseline_quality

    final_state = observation.model_dump()
    config_path = ROOT / "data" / f"{task_name}.json"
    task_config = json.loads(config_path.read_text(encoding="utf-8"))
    grader = DataCleaningGrader()
    score = grader.grade(
        final_state,
        {"total_issues": expected_issues, "max_steps": task_config["max_steps"]},
    )
    return final_state, score
def main() -> None:
    """Run every smoke check, then replay a scripted solution for each task.

    The three standalone checks run first. Each task in ``scripted_runs`` is
    then replayed through ``run_sequence``, and a one-line summary of its
    final state and grader score is printed.
    """
    for smoke_check in (
        assert_invalid_action_consumes_step,
        assert_dependency_gate,
        assert_api_contract,
    ):
        smoke_check()

    # Small builders for the actions used by the scripted solutions below.
    def fill(column: str, strategy: str) -> Action:
        return Action(action_type="fill_missing", column=column, params={"strategy": strategy})

    def convert(column: str, target_dtype: str) -> Action:
        return Action(action_type="convert_dtype", column=column, params={"target_dtype": target_dtype})

    def normalize(column: str) -> Action:
        return Action(action_type="normalize_category", column=column, params={})

    def dedupe() -> Action:
        return Action(action_type="drop_duplicates", column="__all__", params={})

    # task name -> (actions to replay, pending issues the task starts with)
    scripted_runs = {
        "basic_cleaning": (
            [fill("age", "mean"), fill("salary", "median")],
            2,
        ),
        "moderate_cleaning": (
            [
                fill("age", "mean"),
                fill("years_exp", "median"),
                fill("salary", "median"),
                convert("salary", "int"),
                dedupe(),
            ],
            5,
        ),
        "full_pipeline": (
            [
                fill("age", "mean"),
                fill("years_exp", "median"),
                fill("rating", "mean"),
                fill("salary", "median"),
                convert("salary", "int"),
                convert("rating", "float"),
                normalize("city"),
                normalize("department"),
                Action(action_type="create_feature", column="age_group", params={"feature_name": "age_group"}),
                dedupe(),
            ],
            10,
        ),
    }

    for task_name, (actions, expected_issues) in scripted_runs.items():
        final_state, score = run_sequence(task_name, actions, expected_issues)
        pending_count = len(final_state["pending_issues"])
        resolved_count = len(final_state["resolved_issues"])
        steps_left = final_state["steps_remaining"]
        print(
            f"{task_name}: pending={pending_count} resolved={resolved_count} "
            f"steps_remaining={steps_left} grader_score={score}"
        )


if __name__ == "__main__":
    main()
|