import json
import math
from pathlib import Path

from fastapi.testclient import TestClient

from app import app
from env.environment import DataCleaningEnv
from env.graders import DataCleaningGrader
from env.models import Action
|
|
# Directory containing this script; task configs are read from ROOT / "data".
ROOT = Path(__file__).resolve().parent
|
|
|
|
def assert_invalid_action_consumes_step() -> None:
    """Check that a rejected action still costs the agent one step.

    On ``basic_cleaning``, converting ``age`` to ``int`` immediately after
    reset is expected to be rejected: ``info["error"]`` must be
    ``"invalid_action"``, the reward must be 0.01, and ``steps_remaining``
    must still drop by one.

    Raises:
        AssertionError: if any of those expectations does not hold.
    """
    env = DataCleaningEnv("basic_cleaning")
    obs = env.reset()
    # NOTE(review): presumably invalid because ``age`` still has missing
    # values at this point — confirm against the environment's rules.
    _, reward, _, info = env.step(
        Action(action_type="convert_dtype", column="age", params={"target_dtype": "int"})
    )
    # The reward is a float; compare with a tolerance rather than exact equality.
    assert math.isclose(reward, 0.01), reward
    # .get() turns a missing key into an assertion failure with context,
    # rather than a bare KeyError.
    assert info.get("error") == "invalid_action", info
    assert env.steps_remaining == obs.steps_remaining - 1, (
        env.steps_remaining,
        obs.steps_remaining,
    )
|
|
|
|
def assert_dependency_gate() -> None:
    """Check that converting ``salary`` is rejected straight after reset.

    On ``moderate_cleaning``, issuing ``convert_dtype`` on ``salary`` (to
    ``int``) as the first action is expected to fail with
    ``info["error"] == "invalid_action"`` and a reward of 0.01.

    Raises:
        AssertionError: if the action is not rejected as expected.
    """
    env = DataCleaningEnv("moderate_cleaning")
    env.reset()
    # NOTE(review): presumably gated on ``salary`` still having missing
    # values — confirm against the environment's dependency rules.
    _, reward, _, info = env.step(
        Action(action_type="convert_dtype", column="salary", params={"target_dtype": "int"})
    )
    # The reward is a float; compare with a tolerance rather than exact equality.
    assert math.isclose(reward, 0.01), reward
    assert info.get("error") == "invalid_action", info
|
|
|
|
def assert_api_contract() -> None:
    """Smoke-test the HTTP endpoints exposed by ``app``.

    Calls ``/``, ``/health``, ``/metadata``, ``/schema``, ``/reset``,
    ``/step``, ``/state`` and ``/mcp`` in that order. For each, it checks for
    a 200 status and for the keys or values the response is expected to
    carry. The ``/reset``, ``/step`` and ``/state`` calls use one client, so
    they presumably act on shared server-side episode state — confirm
    against ``app``.

    Raises:
        AssertionError: on any unexpected status code or payload.
    """
    # Using the client as a context manager runs the app's startup/shutdown
    # (lifespan) handlers, if any, and closes the client afterwards.
    with TestClient(app) as client:
        root_response = client.get("/")
        assert root_response.status_code == 200, root_response.text
        assert root_response.json()["name"] == "data_cleaning_env"

        health_response = client.get("/health")
        assert health_response.status_code == 200, health_response.text
        assert health_response.json()["status"] == "healthy"

        metadata_response = client.get("/metadata")
        assert metadata_response.status_code == 200, metadata_response.text
        metadata_payload = metadata_response.json()
        assert metadata_payload["name"] == "data_cleaning_env"
        assert "description" in metadata_payload

        schema_response = client.get("/schema")
        assert schema_response.status_code == 200, schema_response.text
        schema_payload = schema_response.json()
        assert {"action", "observation", "state"} <= set(schema_payload.keys())

        reset_response = client.post("/reset", json={"task_name": "basic_cleaning"})
        assert reset_response.status_code == 200, reset_response.text
        assert "pending_issues" in reset_response.json()

        step_response = client.post(
            "/step",
            json={"action_type": "fill_missing", "column": "age", "params": {"strategy": "mean"}},
        )
        assert step_response.status_code == 200, step_response.text
        assert {"observation", "reward", "done", "info"} <= set(step_response.json().keys())

        state_response = client.get("/state")
        assert state_response.status_code == 200, state_response.text
        assert "quality_score" in state_response.json()

        # NOTE(review): the request carries no "method"; only the status code
        # and the JSON-RPC 2.0 envelope of the reply are checked.
        mcp_response = client.post("/mcp", json={"jsonrpc": "2.0", "id": "smoke"})
        assert mcp_response.status_code == 200, mcp_response.text
        assert mcp_response.json()["jsonrpc"] == "2.0"
|
|
|
|
def run_sequence(task_name: str, actions: list[Action], expected_issues: int) -> tuple[dict, float]:
    """Play ``actions`` on a freshly reset ``task_name`` episode and grade it.

    Args:
        task_name: Task identifier. It is passed to ``DataCleaningEnv`` and
            also names the config file ``ROOT / "data" / f"{task_name}.json"``,
            whose ``max_steps`` entry is forwarded to the grader.
        actions: Actions to apply in order. Stepping stops early once the
            environment reports ``done``.
        expected_issues: Number of pending issues the reset observation must
            report. Also passed to the grader as ``total_issues``.

    Returns:
        A pair of the final observation as a dict (``model_dump()``) and the
        value returned by ``DataCleaningGrader().grade``.

    Raises:
        AssertionError: if the initial pending-issue count differs from
            ``expected_issues``, if any step's ``info`` contains ``"error"``,
            or if the final quality score is lower than the initial one.
        FileNotFoundError: if the task config file does not exist.
    """
    # Load the task config up front so a missing or malformed file fails
    # before the episode is played.
    config_path = ROOT / "data" / f"{task_name}.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))

    env = DataCleaningEnv(task_name)
    obs = env.reset()
    assert len(obs.pending_issues) == expected_issues, (task_name, len(obs.pending_issues), expected_issues)
    initial_quality = obs.quality_score

    for action in actions:
        obs, _reward, done, info = env.step(action)
        assert "error" not in info, (task_name, action, info)
        if done:
            # NOTE(review): any remaining actions are skipped without notice.
            break

    # Cleaning must never make the dataset worse.
    assert obs.quality_score >= initial_quality, (task_name, initial_quality, obs.quality_score)
    final_state = obs.model_dump()
    score = DataCleaningGrader().grade(
        final_state,
        {
            "total_issues": expected_issues,
            "max_steps": config["max_steps"],
        },
    )
    return final_state, score
|
|
|
|
def main() -> None:
    """Run the targeted checks, then replay and grade each task's action plan.

    After the three ``assert_*`` checks pass, each task's plan goes through
    ``run_sequence``. One summary line per task is printed, showing the
    pending and resolved issue counts, remaining steps and grader score.
    """
    # Targeted behaviour checks first; each raises AssertionError on failure.
    for check in (
        assert_invalid_action_consumes_step,
        assert_dependency_gate,
        assert_api_contract,
    ):
        check()

    # Small factories: each call builds a fresh Action (and a fresh params dict).
    def fill(column: str, strategy: str) -> Action:
        return Action(action_type="fill_missing", column=column, params={"strategy": strategy})

    def convert(column: str, target_dtype: str) -> Action:
        return Action(action_type="convert_dtype", column=column, params={"target_dtype": target_dtype})

    def normalize(column: str) -> Action:
        return Action(action_type="normalize_category", column=column, params={})

    def dedupe() -> Action:
        return Action(action_type="drop_duplicates", column="__all__", params={})

    # (task name, actions to replay, issue count expected right after reset)
    plans = (
        (
            "basic_cleaning",
            [fill("age", "mean"), fill("salary", "median")],
            2,
        ),
        (
            "moderate_cleaning",
            [
                fill("age", "mean"),
                fill("years_exp", "median"),
                fill("salary", "median"),
                convert("salary", "int"),
                dedupe(),
            ],
            5,
        ),
        (
            "full_pipeline",
            [
                fill("age", "mean"),
                fill("years_exp", "median"),
                fill("rating", "mean"),
                fill("salary", "median"),
                convert("salary", "int"),
                convert("rating", "float"),
                normalize("city"),
                normalize("department"),
                Action(action_type="create_feature", column="age_group", params={"feature_name": "age_group"}),
                dedupe(),
            ],
            10,
        ),
    )

    for name, plan, issue_count in plans:
        state, grade = run_sequence(name, plan, issue_count)
        summary = (
            f"{name}: "
            f"pending={len(state['pending_issues'])} "
            f"resolved={len(state['resolved_issues'])} "
            f"steps_remaining={state['steps_remaining']} "
            f"grader_score={grade}"
        )
        print(summary)
|
|
|
|
# Run the smoke checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|