File size: 1,593 Bytes
a03a89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Additional environment smoke tests."""

from __future__ import annotations

import pytest

minigrid = pytest.importorskip("minigrid")
assert minigrid is not None

from MiniGridEnv.env.config import EnvConfig
from MiniGridEnv.env.minigrid_env import MiniGridEnvironment
from MiniGridEnv.env.models import MiniGridAction


def test_step_before_reset_raises():
    """Calling step() on an environment that was never reset raises RuntimeError."""
    cfg = EnvConfig(level_name="GoToRedBall")
    fresh_env = MiniGridEnvironment(config=cfg)
    with pytest.raises(RuntimeError):
        forward = MiniGridAction(command="go forward")
        fresh_env.step(forward)


def test_step_after_done_raises():
    """Once the one-step budget is spent, a further step() raises RuntimeError."""
    one_step_cfg = EnvConfig(level_name="GoToRedBall", max_steps_override=1)
    env = MiniGridEnvironment(config=one_step_cfg)

    def _forward():
        return env.step(MiniGridAction(command="go forward"))

    env.reset(seed=1)
    # With max_steps_override=1 this single step is expected to finish the
    # episode (per the test's intent); the next call must then be rejected.
    _forward()
    with pytest.raises(RuntimeError):
        _forward()


def test_reset_creates_new_episode_id():
    """Two consecutive resets (seeds 1 and 2) yield different episode ids."""
    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    seen_ids = []
    for seed in (1, 2):
        env.reset(seed=seed)
        seen_ids.append(env.state.episode_id)
    first_id, second_id = seen_ids
    assert first_id != second_id


def test_reset_honors_explicit_episode_id():
    """An episode_id passed to reset() is reported unchanged by env.state."""
    requested_id = "episode-123"
    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    env.reset(seed=1, episode_id=requested_id)
    assert env.state.episode_id == requested_id


def test_history_can_be_disabled():
    """With include_history=False, the observation from step() has an empty history list."""
    cfg = EnvConfig(level_name="GoToRedBall", include_history=False)
    env = MiniGridEnvironment(config=cfg)
    env.reset(seed=5)
    observation = env.step(MiniGridAction(command="go forward"))
    assert observation.history == []