MiniGridEnv / tests /test_contract.py
yashu2000's picture
Upload folder using huggingface_hub
6951424 verified
"""Contract-style tests for MiniGridEnv reset/step/state behavior."""
from __future__ import annotations
import pytest
minigrid = pytest.importorskip("minigrid")
assert minigrid is not None
from MiniGridEnv.env.config import EnvConfig
from MiniGridEnv.env.minigrid_env import MiniGridEnvironment
from MiniGridEnv.env.models import MiniGridAction, MiniGridObservation
def test_reset_returns_valid_observation():
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
obs = env.reset(seed=123)
assert isinstance(obs, MiniGridObservation)
assert obs.step_idx == 0
assert obs.history == []
assert obs.done is False
assert isinstance(obs.text, str) and len(obs.text.strip()) > 0
def test_observation_text_is_natural_language():
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
obs = env.reset(seed=123)
assert "Mission:" in obs.text
assert "You are facing" in obs.text
def test_mission_is_populated():
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
obs = env.reset(seed=123)
assert isinstance(obs.mission, str) and len(obs.mission.strip()) > 0
def test_step_accepts_valid_command():
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
env.reset(seed=123)
obs = env.step(MiniGridAction(command="go forward"))
assert isinstance(obs, MiniGridObservation)
assert obs.step_idx == 1
def test_step_handles_invalid_command():
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
env.reset(seed=123)
obs = env.step(MiniGridAction(command="fly away"))
assert isinstance(obs, MiniGridObservation)
assert obs.last_action == "go forward"
assert env.state.invalid_actions == 1
def test_episode_terminates_on_success():
try:
from minigrid.envs.babyai import BotAgent # type: ignore
except Exception:
pytest.skip("BotAgent unavailable for success-path test")
int_to_text = {
0: "turn left",
1: "turn right",
2: "go forward",
3: "pickup",
4: "drop",
5: "toggle",
6: "done",
}
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
obs = env.reset(seed=123)
bot = BotAgent(env._gym_env.unwrapped) # type: ignore[attr-defined]
while not obs.done and obs.step_idx < obs.max_steps:
raw_obs = env._last_obs # type: ignore[attr-defined]
action = bot.act(raw_obs)
obs = env.step(MiniGridAction(command=int_to_text.get(int(action), "done")))
assert obs.done is True
assert env.state.completed is True
assert env.state.total_reward > 0.0
def test_episode_truncates_on_max_steps():
env = MiniGridEnvironment(
config=EnvConfig(level_name="GoToRedBall", max_steps_override=1)
)
env.reset(seed=123)
obs = env.step(MiniGridAction(command="go forward"))
assert obs.done is True
assert env.state.truncated is True
def test_deterministic_with_seed():
actions = ["turn left", "turn right", "go forward", "go forward"]
env_a = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
env_b = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
obs_a = env_a.reset(seed=777)
obs_b = env_b.reset(seed=777)
assert obs_a.text == obs_b.text
for command in actions:
if obs_a.done or obs_b.done:
assert obs_a.done == obs_b.done
break
obs_a = env_a.step(MiniGridAction(command=command))
obs_b = env_b.step(MiniGridAction(command=command))
assert obs_a.text == obs_b.text
assert obs_a.done == obs_b.done
def test_state_tracks_metrics():
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
env.reset(seed=123)
env.step(MiniGridAction(command="go forward"))
state = env.state
assert state.steps_taken == 1
assert state.valid_actions == 1
assert state.action_distribution.get("go forward", 0) == 1
def test_all_seven_actions_accepted():
commands = [
"turn left",
"turn right",
"go forward",
"pickup",
"drop",
"toggle",
"done",
]
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
for idx, command in enumerate(commands):
env.reset(seed=100 + idx)
obs = env.step(MiniGridAction(command=command))
assert isinstance(obs, MiniGridObservation)
def test_history_grows_each_step():
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
env.reset(seed=123)
obs = env.step(MiniGridAction(command="go forward"))
assert len(obs.history) == 1
obs = env.step(MiniGridAction(command="turn left"))
assert len(obs.history) == 2
def test_different_levels_produce_different_missions():
env_a = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
env_b = MiniGridEnvironment(config=EnvConfig(level_name="GoToObj"))
obs_a = env_a.reset(seed=123)
obs_b = env_b.reset(seed=123)
assert obs_a.mission != obs_b.mission
def test_reset_level_kwarg_switches_level_on_same_env():
"""Remote PT clients can pass level/level_name in reset payload."""
env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
obs_ball = env.reset(seed=123)
assert obs_ball.level_name == "GoToRedBall"
obs_obj = env.reset(seed=123, level_name="GoToObj")
assert obs_obj.level_name == "GoToObj"
assert obs_obj.mission != obs_ball.mission