Spaces:
Sleeping
Sleeping
File size: 1,593 Bytes
a03a89b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | """Additional environment smoke tests."""
from __future__ import annotations
import pytest
minigrid = pytest.importorskip("minigrid")
assert minigrid is not None
from MiniGridEnv.env.config import EnvConfig
from MiniGridEnv.env.minigrid_env import MiniGridEnvironment
from MiniGridEnv.env.models import MiniGridAction
def test_step_before_reset_raises():
    """Check that step() on a never-reset environment raises RuntimeError."""
    config = EnvConfig(level_name="GoToRedBall")
    fresh_env = MiniGridEnvironment(config=config)

    # No reset() call on purpose: the step below is the very first interaction.
    with pytest.raises(RuntimeError):
        forward = MiniGridAction(command="go forward")
        fresh_env.step(forward)
def test_step_after_done_raises():
    """Check that a second step raises RuntimeError when max_steps_override=1.

    NOTE(review): presumably the single permitted step ends the episode, so
    the follow-up step is rejected -- confirm against MiniGridEnvironment.
    """
    one_step_config = EnvConfig(level_name="GoToRedBall", max_steps_override=1)
    env = MiniGridEnvironment(config=one_step_config)
    env.reset(seed=1)

    # First step is expected to succeed and consume the one-step budget.
    env.step(MiniGridAction(command="go forward"))

    # Any further step must be refused.
    with pytest.raises(RuntimeError):
        extra_action = MiniGridAction(command="go forward")
        env.step(extra_action)
def test_reset_creates_new_episode_id():
    """Check that two consecutive resets yield different episode ids."""
    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))

    # Reset twice with different seeds, recording the episode id after each.
    observed_ids = []
    for seed in (1, 2):
        env.reset(seed=seed)
        observed_ids.append(env.state.episode_id)

    first_id, second_id = observed_ids
    assert first_id != second_id
def test_reset_honors_explicit_episode_id():
    """Check that an episode_id passed to reset() is exposed on env.state."""
    requested_id = "episode-123"
    config = EnvConfig(level_name="GoToRedBall")
    env = MiniGridEnvironment(config=config)

    env.reset(seed=1, episode_id=requested_id)

    assert env.state.episode_id == requested_id
def test_history_can_be_disabled():
    """Check that include_history=False gives an empty history after a step."""
    no_history_config = EnvConfig(level_name="GoToRedBall", include_history=False)
    env = MiniGridEnvironment(config=no_history_config)
    env.reset(seed=5)

    observation = env.step(MiniGridAction(command="go forward"))

    # Compared against an empty list exactly, not merely for falsiness.
    assert observation.history == []
|