# data-cleaning-openenv/tests/test_environment.py
# yashmarathe -- refactor: full openenv protocol compliance (1a55ff4)
"""Tests for the core environment logic (openenv Environment interface)."""
from __future__ import annotations
import pytest
from data_cleaning_env.models import ActionType, CleaningAction, FillStrategy
from data_cleaning_env.server.environment import DataCleaningEnvironment
class TestEnvironmentReset:
    """Tests for ``DataCleaningEnvironment.reset``."""

    def test_reset_returns_observation(self) -> None:
        """Resetting on the "easy" task reports that task, step 0 and not done."""
        env = DataCleaningEnvironment()
        obs = env.reset(task="easy")
        assert obs.task == "easy"
        assert obs.step == 0
        assert obs.done is False

    def test_unknown_task_raises(self) -> None:
        """Resetting with an unrecognised task name raises ``ValueError``."""
        env = DataCleaningEnvironment()
        with pytest.raises(ValueError, match="Unknown task"):
            env.reset(task="invalid")

    def test_episode_stores_dirty_snapshot(self) -> None:
        """After reset, ``_ep`` contains "dirty_df" and "initial_quality"."""
        env = DataCleaningEnvironment()
        # Only the side effect on the environment matters here, so the
        # returned observation is deliberately not bound to a name.
        env.reset(task="easy")
        ep = env._ep
        assert "dirty_df" in ep
        assert "initial_quality" in ep
class TestEnvironmentStep:
    """Tests for ``DataCleaningEnvironment.step`` across several action types."""

    def test_fill_missing_increases_reward(self) -> None:
        """A median fill on a column with missing values advances to step 1."""
        # NOTE(review): despite the name, nothing here asserts on reward;
        # only the step counter is checked. Confirm whether a reward
        # assertion was intended.
        env = DataCleaningEnvironment()
        first_obs = env.reset(task="easy")

        def missing_count(issues):
            # Column issues may be plain dicts or objects with attributes;
            # anything exposing neither form counts as having no missing values.
            if isinstance(issues, dict):
                return issues.get("missing_count", 0)
            return getattr(issues, "missing_count", 0)

        # Pick the first column that reports at least one missing value.
        col = next(
            (
                name
                for name, issues in first_obs.column_issues.items()
                if missing_count(issues) > 0
            ),
            None,
        )
        assert col is not None, "Easy task should have missing values"

        fill = CleaningAction(
            action_type=ActionType.fill_missing,
            column=col,
            strategy=FillStrategy.median,
        )
        next_obs = env.step(fill)
        assert next_obs.step == 1

    def test_done_ends_episode(self) -> None:
        """Stepping with the ``done`` action yields an observation marked done."""
        env = DataCleaningEnvironment()
        env.reset(task="easy")
        finish = CleaningAction(action_type=ActionType.done)
        final_obs = env.step(finish)
        assert final_obs.done is True

    def test_step_after_done_raises(self) -> None:
        """A second step after ``done`` raises ``ValueError`` ("already done")."""
        env = DataCleaningEnvironment()
        env.reset(task="easy")
        env.step(CleaningAction(action_type=ActionType.done))
        with pytest.raises(ValueError, match="already done"):
            env.step(CleaningAction(action_type=ActionType.done))

    def test_step_without_reset_raises(self) -> None:
        """Stepping before any reset raises ``ValueError`` ("No active episode")."""
        env = DataCleaningEnvironment()
        with pytest.raises(ValueError, match="No active episode"):
            env.step(CleaningAction(action_type=ActionType.done))

    def test_profile_column_is_free(self) -> None:
        """Profiling a column returns a profile and leaves ``cost_used`` at 0.0."""
        env = DataCleaningEnvironment()
        first_obs = env.reset(task="easy")
        target = first_obs.columns[0]
        profile = CleaningAction(
            action_type=ActionType.profile_column,
            column=target,
        )
        next_obs = env.step(profile)
        assert next_obs.profile_result is not None
        assert env._ep["cost_used"] == 0.0

    def test_observation_columns_reflect_current_state(self) -> None:
        """Renaming a column is visible in the next observation's columns.

        The new name must appear in ``columns`` and the old name must be gone.
        """
        env = DataCleaningEnvironment()
        first_obs = env.reset(task="easy")
        old_name = first_obs.columns[0]
        renamed = old_name + "_renamed"
        rename = CleaningAction(
            action_type=ActionType.rename_column,
            column=old_name,
            new_name=renamed,
        )
        next_obs = env.step(rename)
        assert renamed in next_obs.columns
        assert old_name not in next_obs.columns