Spaces:
Sleeping
Sleeping
File size: 5,537 Bytes
"""Contract-style tests for MiniGridEnv reset/step/state behavior."""
from __future__ import annotations
import pytest
minigrid = pytest.importorskip("minigrid")
assert minigrid is not None
from MiniGridEnv.env.config import EnvConfig
from MiniGridEnv.env.minigrid_env import MiniGridEnvironment
from MiniGridEnv.env.models import MiniGridAction, MiniGridObservation
def test_reset_returns_valid_observation():
    """A seeded reset yields a fresh, non-terminal observation with text."""
    config = EnvConfig(level_name="GoToRedBall")
    environment = MiniGridEnvironment(config=config)
    first = environment.reset(seed=123)

    assert isinstance(first, MiniGridObservation)
    # A fresh episode starts at step zero with an empty history.
    assert first.step_idx == 0
    assert first.history == []
    assert first.done is False
    # The rendered text must contain something other than whitespace.
    assert isinstance(first.text, str)
    assert first.text.strip() != ""
def test_observation_text_is_natural_language():
    """The rendered observation contains the mission and facing phrases."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    rendered = environment.reset(seed=123).text
    for marker in ("Mission:", "You are facing"):
        assert marker in rendered
def test_mission_is_populated():
    """The reset observation exposes a non-blank mission string."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    mission = environment.reset(seed=123).mission
    assert isinstance(mission, str)
    assert mission.strip() != ""
def test_step_accepts_valid_command():
    """A recognized command advances the episode by exactly one step."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)
    forward = MiniGridAction(command="go forward")
    after = environment.step(forward)
    assert isinstance(after, MiniGridObservation)
    assert after.step_idx == 1
def test_step_handles_invalid_command():
    """An unrecognized command does not raise and is counted as invalid.

    The expected ``last_action`` of ``"go forward"`` pins the fallback action
    this test expects the environment to record for an unparseable command.
    """
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)
    gibberish = MiniGridAction(command="fly away")
    result = environment.step(gibberish)

    assert isinstance(result, MiniGridObservation)
    assert result.last_action == "go forward"
    assert environment.state.invalid_actions == 1
def test_episode_terminates_on_success():
    """Drive an episode with the BabyAI bot and expect a successful finish.

    The test is skipped when the bot class cannot be imported. Only import
    failures are turned into a skip; any other error raised while importing
    the bot surfaces as a real test failure instead of being hidden.
    """
    try:
        from minigrid.envs.babyai import BotAgent  # type: ignore
    except ImportError:  # also covers ModuleNotFoundError
        pytest.skip("BotAgent unavailable for success-path test")

    # Integer action ids emitted by the bot, mapped to the environment's text
    # commands. NOTE(review): assumed to mirror MiniGrid's Actions enum order
    # (left, right, forward, pickup, drop, toggle, done) — confirm.
    int_to_text = {
        0: "turn left",
        1: "turn right",
        2: "go forward",
        3: "pickup",
        4: "drop",
        5: "toggle",
        6: "done",
    }
    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    obs = env.reset(seed=123)
    # The bot is given the unwrapped underlying gym environment.
    bot = BotAgent(env._gym_env.unwrapped)  # type: ignore[attr-defined]
    # The step_idx < max_steps bound keeps the loop finite even if the bot
    # never reaches the goal.
    while not obs.done and obs.step_idx < obs.max_steps:
        raw_obs = env._last_obs  # type: ignore[attr-defined]
        action = bot.act(raw_obs)
        # Unknown action ids fall back to the "done" command.
        command = int_to_text.get(int(action), "done")
        obs = env.step(MiniGridAction(command=command))

    assert obs.done is True
    assert env.state.completed is True
    assert env.state.total_reward > 0.0
def test_episode_truncates_on_max_steps():
    """With a one-step budget, the first step ends the episode by truncation."""
    one_step_budget = EnvConfig(level_name="GoToRedBall", max_steps_override=1)
    environment = MiniGridEnvironment(config=one_step_budget)
    environment.reset(seed=123)

    final = environment.step(MiniGridAction(command="go forward"))

    assert final.done is True
    assert environment.state.truncated is True
def test_deterministic_with_seed():
    """Two environments reset with the same seed produce identical trajectories."""
    script = ("turn left", "turn right", "go forward", "go forward")
    left = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    right = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    obs_left, obs_right = left.reset(seed=777), right.reset(seed=777)
    assert obs_left.text == obs_right.text

    for command in script:
        if obs_left.done or obs_right.done:
            # Stop once either episode has ended; both must agree that it has.
            assert obs_left.done == obs_right.done
            break
        obs_left = left.step(MiniGridAction(command=command))
        obs_right = right.step(MiniGridAction(command=command))
        assert (obs_left.text, obs_left.done) == (obs_right.text, obs_right.done)
def test_state_tracks_metrics():
    """One valid step is reflected in the step, validity and action counters."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)
    environment.step(MiniGridAction(command="go forward"))

    snapshot = environment.state
    assert snapshot.steps_taken == 1
    assert snapshot.valid_actions == 1
    forward_count = snapshot.action_distribution.get("go forward", 0)
    assert forward_count == 1
def test_all_seven_actions_accepted():
    """Each of the seven text commands is accepted from a freshly reset episode."""
    commands = (
        "turn left",
        "turn right",
        "go forward",
        "pickup",
        "drop",
        "toggle",
        "done",
    )
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    for seed, command in enumerate(commands, start=100):
        # Reset with a distinct seed before every command so each one is
        # issued at the start of a new episode.
        environment.reset(seed=seed)
        result = environment.step(MiniGridAction(command=command))
        assert isinstance(result, MiniGridObservation)
def test_history_grows_each_step():
    """Each step appends exactly one entry to the observation history."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)
    for expected_len, command in enumerate(("go forward", "turn left"), start=1):
        latest = environment.step(MiniGridAction(command=command))
        assert len(latest.history) == expected_len
def test_different_levels_produce_different_missions():
    """Two distinct levels reset with the same seed yield different missions."""
    levels = ("GoToRedBall", "GoToObj")
    # Build both environments first, then reset them, in the same order as before.
    envs = [MiniGridEnvironment(config=EnvConfig(level_name=name)) for name in levels]
    ball_mission, obj_mission = (env.reset(seed=123).mission for env in envs)
    assert ball_mission != obj_mission
def test_reset_level_kwarg_switches_level_on_same_env():
    """Passing ``level_name`` to ``reset`` switches the level of an existing env.

    Remote PT clients can pass level/level_name in the reset payload.
    """
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))

    default_level = environment.reset(seed=123)
    assert default_level.level_name == "GoToRedBall"

    switched = environment.reset(seed=123, level_name="GoToObj")
    assert switched.level_name == "GoToObj"
    assert switched.mission != default_level.mission
|