# data-cleaning-openenv / test_env.py
# Dishaaa25's picture
# Upload folder using huggingface_hub
# c22bf49 verified
import json
from pathlib import Path
from fastapi.testclient import TestClient
from app import app
from env.environment import DataCleaningEnv
from env.graders import DataCleaningGrader
from env.models import Action
# Absolute directory containing this file; run_sequence() reads each task's
# JSON config from ROOT / "data" / "<task_name>.json".
ROOT = Path(__file__).resolve().parent
def assert_invalid_action_consumes_step() -> None:
    """Check that a rejected action still uses up one step of the budget.

    Resets the ``basic_cleaning`` task and immediately submits a
    ``convert_dtype`` on ``age``. The environment is expected to answer with a
    reward of 0.01 and ``info["error"] == "invalid_action"``, and its
    ``steps_remaining`` must be one below the value reported by ``reset()``.
    """
    environment = DataCleaningEnv("basic_cleaning")
    obs = environment.reset()

    # NOTE(review): presumably rejected because `age` has not been filled yet
    # -- confirm against the environment's validation rules.
    premature_cast = Action(
        action_type="convert_dtype",
        column="age",
        params={"target_dtype": "int"},
    )
    step_result = environment.step(premature_cast)
    reward = step_result[1]
    info = step_result[3]

    assert reward == 0.01
    assert info["error"] == "invalid_action"
    # Compared against the observation returned by reset(), read after the step.
    assert environment.steps_remaining == obs.steps_remaining - 1
def assert_dependency_gate() -> None:
    """Check that an out-of-order ``convert_dtype`` is rejected.

    On a freshly reset ``moderate_cleaning`` task, converting ``salary`` to
    ``int`` as the first action must yield a reward of 0.01 and
    ``info["error"] == "invalid_action"``.
    """
    environment = DataCleaningEnv("moderate_cleaning")
    environment.reset()

    # NOTE(review): presumably gated on `salary` being filled first (main()
    # fills it before converting) -- confirm against the environment.
    gated_action = Action(
        action_type="convert_dtype",
        column="salary",
        params={"target_dtype": "int"},
    )
    step_result = environment.step(gated_action)
    reward = step_result[1]
    info = step_result[3]

    assert reward == 0.01
    assert info["error"] == "invalid_action"
def assert_api_contract() -> None:
    """Smoke-test the HTTP surface of ``app`` through a FastAPI ``TestClient``.

    Endpoints are exercised in a fixed order: ``GET /``, ``GET /health``,
    ``GET /metadata``, ``GET /schema``, ``POST /reset``, ``POST /step``,
    ``GET /state`` and ``POST /mcp``. For each one the status code and/or a
    minimal set of response keys is asserted.
    """
    http = TestClient(app)

    # Root endpoint advertises the environment name.
    root = http.get("/")
    assert root.status_code == 200
    assert root.json()["name"] == "data_cleaning_env"

    # Health probe.
    health = http.get("/health")
    assert health.json()["status"] == "healthy"

    # Metadata must carry a name and a description.
    metadata = http.get("/metadata")
    assert metadata.status_code == 200
    metadata_body = metadata.json()
    assert metadata_body["name"] == "data_cleaning_env"
    assert "description" in metadata_body

    # Schema must describe at least action, observation and state.
    schema = http.get("/schema")
    assert schema.status_code == 200
    schema_body = schema.json()
    assert {"action", "observation", "state"}.issubset(schema_body.keys())

    # Start an episode; the reset payload exposes the pending issues.
    reset = http.post("/reset", json={"task_name": "basic_cleaning"})
    assert reset.status_code == 200
    assert "pending_issues" in reset.json()

    # One step after the reset above; the reply carries the usual step 4-tuple
    # fields as JSON keys.
    step_payload = {
        "action_type": "fill_missing",
        "column": "age",
        "params": {"strategy": "mean"},
    }
    step = http.post("/step", json=step_payload)
    assert step.status_code == 200
    assert {"observation", "reward", "done", "info"}.issubset(step.json().keys())

    # State endpoint reports a quality score.
    state = http.get("/state")
    assert state.status_code == 200
    assert "quality_score" in state.json()

    # Minimal JSON-RPC envelope (no method) must still be answered as JSON-RPC 2.0.
    mcp = http.post("/mcp", json={"jsonrpc": "2.0", "id": "smoke"})
    assert mcp.status_code == 200
    assert mcp.json()["jsonrpc"] == "2.0"
def run_sequence(task_name: str, actions: list[Action], expected_issues: int) -> tuple[dict, float]:
    """Replay a scripted list of actions on one task and grade the outcome.

    Args:
        task_name: Task to instantiate; also the stem of the JSON config read
            from ``ROOT / "data"``.
        actions: Actions applied in order. Replay stops early as soon as the
            environment reports ``done``.
        expected_issues: Number of pending issues the initial observation must
            report; also passed to the grader as ``total_issues``.

    Returns:
        A pair ``(final_state, score)``: the last observation dumped via
        ``model_dump()`` and the value returned by ``DataCleaningGrader.grade``.

    Raises:
        AssertionError: If the initial issue count differs from
            ``expected_issues``, if any step's ``info`` contains an ``"error"``
            key, or if the final quality score is below the initial one.
    """
    environment = DataCleaningEnv(task_name)
    observation = environment.reset()

    issue_count = len(observation.pending_issues)
    assert issue_count == expected_issues, (task_name, issue_count, expected_issues)
    baseline_quality = observation.quality_score

    for scripted_action in actions:
        observation, _reward, finished, step_info = environment.step(scripted_action)
        assert "error" not in step_info, (task_name, scripted_action, step_info)
        if finished:
            break

    # Holds trivially when `actions` is empty: the observation is the reset one.
    assert observation.quality_score >= baseline_quality

    final_state = observation.model_dump()

    # The step budget lives in the task's JSON config, not in the observation.
    config_path = ROOT / "data" / f"{task_name}.json"
    task_config = json.loads(config_path.read_text(encoding="utf-8"))

    grader = DataCleaningGrader()
    grading_context = {
        "total_issues": expected_issues,
        "max_steps": task_config["max_steps"],
    }
    score = grader.grade(final_state, grading_context)
    return final_state, score
def main() -> None:
    """Run all smoke checks, then replay one scripted episode per task.

    First runs the invalid-action, dependency-gate and API-contract checks.
    Then, for ``basic_cleaning`` (2 expected issues), ``moderate_cleaning`` (5)
    and ``full_pipeline`` (10), feeds a fixed action list through
    ``run_sequence`` and prints one summary line per task with the pending and
    resolved issue counts, the remaining steps and the grader score.
    """
    assert_invalid_action_consumes_step()
    assert_dependency_gate()
    assert_api_contract()

    # Small local constructors to keep the scripted sequences readable.
    def fill(column: str, strategy: str) -> Action:
        return Action(action_type="fill_missing", column=column, params={"strategy": strategy})

    def cast(column: str, target_dtype: str) -> Action:
        return Action(action_type="convert_dtype", column=column, params={"target_dtype": target_dtype})

    def normalize(column: str) -> Action:
        return Action(action_type="normalize_category", column=column, params={})

    def dedupe() -> Action:
        return Action(action_type="drop_duplicates", column="__all__", params={})

    # Every action is built up front, before any episode is replayed.
    basic_actions = [
        fill("age", "mean"),
        fill("salary", "median"),
    ]
    moderate_actions = [
        fill("age", "mean"),
        fill("years_exp", "median"),
        fill("salary", "median"),
        cast("salary", "int"),
        dedupe(),
    ]
    full_actions = [
        fill("age", "mean"),
        fill("years_exp", "median"),
        fill("rating", "mean"),
        fill("salary", "median"),
        cast("salary", "int"),
        cast("rating", "float"),
        normalize("city"),
        normalize("department"),
        Action(action_type="create_feature", column="age_group", params={"feature_name": "age_group"}),
        dedupe(),
    ]

    # (task name, scripted actions, expected number of initial pending issues)
    episodes = [
        ("basic_cleaning", basic_actions, 2),
        ("moderate_cleaning", moderate_actions, 5),
        ("full_pipeline", full_actions, 10),
    ]

    for task_name, scripted_actions, issue_total in episodes:
        end_state, grade = run_sequence(task_name, scripted_actions, issue_total)
        pending_count = len(end_state["pending_issues"])
        resolved_count = len(end_state["resolved_issues"])
        steps_left = end_state["steps_remaining"]
        summary = (
            f"{task_name}: pending={pending_count} resolved={resolved_count} "
            f"steps_remaining={steps_left} grader_score={grade}"
        )
        print(summary)
# Run the smoke checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()