Spaces:
Sleeping
Sleeping
| """Contract-style tests for MiniGridEnv reset/step/state behavior.""" | |
| from __future__ import annotations | |
| import pytest | |
| minigrid = pytest.importorskip("minigrid") | |
| assert minigrid is not None | |
| from MiniGridEnv.env.config import EnvConfig | |
| from MiniGridEnv.env.minigrid_env import MiniGridEnvironment | |
| from MiniGridEnv.env.models import MiniGridAction, MiniGridObservation | |
def test_reset_returns_valid_observation():
    """A seeded reset yields a fresh, non-terminal observation with text."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    first_obs = environment.reset(seed=123)

    assert isinstance(first_obs, MiniGridObservation)
    assert first_obs.step_idx == 0
    assert first_obs.history == []
    assert first_obs.done is False
    # Text must be a string with at least one non-whitespace character.
    assert isinstance(first_obs.text, str) and first_obs.text.strip() != ""
def test_observation_text_is_natural_language():
    """The rendered observation contains the mission and facing-direction cues."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    first_obs = environment.reset(seed=123)
    for expected_phrase in ("Mission:", "You are facing"):
        assert expected_phrase in first_obs.text
def test_mission_is_populated():
    """Reset exposes a non-blank mission string."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    mission = environment.reset(seed=123).mission
    assert isinstance(mission, str) and mission.strip() != ""
def test_step_accepts_valid_command():
    """A recognised command returns an observation and advances step_idx to 1."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)

    stepped = environment.step(MiniGridAction(command="go forward"))

    assert isinstance(stepped, MiniGridObservation)
    assert stepped.step_idx == 1
def test_step_handles_invalid_command():
    """An unrecognised command is counted as invalid instead of raising.

    This test pins the expectation that the returned observation reports
    ``"go forward"`` as its ``last_action`` after an unparseable command.
    """
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)

    result = environment.step(MiniGridAction(command="fly away"))

    assert isinstance(result, MiniGridObservation)
    assert result.last_action == "go forward"
    assert environment.state.invalid_actions == 1
def _make_expert_policy(env):
    """Build ``policy(last_action) -> action id`` backed by a BabyAI expert bot.

    ``env`` is the text-wrapper environment and must already be reset. The bot
    is bound to ``env._gym_env.unwrapped``, so call this after ``reset`` so it
    plans against the live episode.

    Current minigrid ships the planner as
    ``minigrid.utils.baby_ai_bot.BabyAIBot``, which reads the env state
    directly through ``replan(action_taken)``. The older
    ``minigrid.envs.babyai.BotAgent`` import (driven with ``act(raw_obs)``) is
    kept as a fallback. Skips the calling test when neither can be imported.
    """
    unwrapped = env._gym_env.unwrapped  # type: ignore[attr-defined]
    try:
        from minigrid.utils.baby_ai_bot import BabyAIBot  # type: ignore
    except ImportError:
        pass
    else:
        return BabyAIBot(unwrapped).replan

    try:
        from minigrid.envs.babyai import BotAgent  # type: ignore
    except ImportError:
        pytest.skip("BotAgent unavailable for success-path test")
    legacy_bot = BotAgent(unwrapped)
    # The legacy bot consumes the raw gym observation, not the text view.
    return lambda _last_action: legacy_bot.act(env._last_obs)  # type: ignore[attr-defined]


def test_episode_terminates_on_success():
    """Driving GoToRedBall with an expert bot must end in a rewarded success."""
    # MiniGrid integer action ids -> text commands accepted by the wrapper.
    int_to_text = {
        0: "turn left",
        1: "turn right",
        2: "go forward",
        3: "pickup",
        4: "drop",
        5: "toggle",
        6: "done",
    }
    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    obs = env.reset(seed=123)
    policy = _make_expert_policy(env)
    last_action = None
    # The step_idx bound keeps the loop finite even if the bot never succeeds.
    while not obs.done and obs.step_idx < obs.max_steps:
        action = policy(last_action)
        obs = env.step(MiniGridAction(command=int_to_text.get(int(action), "done")))
        last_action = action
    assert obs.done is True
    assert env.state.completed is True
    assert env.state.total_reward > 0.0
def test_episode_truncates_on_max_steps():
    """With a one-step budget, the first step ends the episode via truncation."""
    one_step_config = EnvConfig(level_name="GoToRedBall", max_steps_override=1)
    environment = MiniGridEnvironment(config=one_step_config)
    environment.reset(seed=123)

    final_obs = environment.step(MiniGridAction(command="go forward"))

    assert final_obs.done is True
    assert environment.state.truncated is True
def test_deterministic_with_seed():
    """Two environments reset with the same seed must evolve in lockstep."""
    first_env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    second_env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    left = first_env.reset(seed=777)
    right = second_env.reset(seed=777)
    assert left.text == right.text

    for cmd in ("turn left", "turn right", "go forward", "go forward"):
        if left.done or right.done:
            # Once either side has finished, both must agree it finished.
            assert left.done == right.done
            break
        left = first_env.step(MiniGridAction(command=cmd))
        right = second_env.step(MiniGridAction(command=cmd))
        assert left.text == right.text
        assert left.done == right.done
def test_state_tracks_metrics():
    """State counts steps, valid actions and per-command usage after one step."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)
    environment.step(MiniGridAction(command="go forward"))

    snapshot = environment.state
    assert snapshot.steps_taken == 1
    assert snapshot.valid_actions == 1
    usage = snapshot.action_distribution
    assert usage.get("go forward", 0) == 1
def test_all_seven_actions_accepted():
    """Every MiniGrid primitive command is parsed as a valid action.

    An unrecognised command also yields a ``MiniGridObservation`` (it is
    counted in ``state.invalid_actions`` rather than raising). The returned
    type alone therefore does not prove acceptance, so the invalid-action
    counter must also stay unchanged across each step.
    """
    commands = [
        "turn left",
        "turn right",
        "go forward",
        "pickup",
        "drop",
        "toggle",
        "done",
    ]
    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    for idx, command in enumerate(commands):
        # Fresh episode per command so "done" or terminal states don't leak.
        env.reset(seed=100 + idx)
        # Compare against a pre-step snapshot, so the check holds whether or
        # not reset() clears the counter.
        invalid_before = env.state.invalid_actions
        obs = env.step(MiniGridAction(command=command))
        assert isinstance(obs, MiniGridObservation)
        assert env.state.invalid_actions == invalid_before, command
def test_history_grows_each_step():
    """Each step appends exactly one entry to the observation history."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)
    for expected_len, cmd in enumerate(("go forward", "turn left"), start=1):
        latest = environment.step(MiniGridAction(command=cmd))
        assert len(latest.history) == expected_len
def test_different_levels_produce_different_missions():
    """Two different levels reset with the same seed produce different missions."""
    # Build both environments first, then reset them in the same order.
    environments = [
        MiniGridEnvironment(config=EnvConfig(level_name=level))
        for level in ("GoToRedBall", "GoToObj")
    ]
    red_ball_obs, go_to_obj_obs = [e.reset(seed=123) for e in environments]
    assert red_ball_obs.mission != go_to_obj_obs.mission
def test_reset_level_kwarg_switches_level_on_same_env():
    """Passing ``level_name`` to ``reset`` retargets an existing environment.

    Remote PT clients can send level/level_name in the reset payload, so the
    same instance must report the new level and its mission after reset.
    """
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))

    initial = environment.reset(seed=123)
    assert initial.level_name == "GoToRedBall"

    switched = environment.reset(seed=123, level_name="GoToObj")
    assert switched.level_name == "GoToObj"
    assert switched.mission != initial.mission