File size: 5,537 Bytes
a03a89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6951424
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""Contract-style tests for MiniGridEnv reset/step/state behavior."""

from __future__ import annotations

import pytest

# Skip every test in this module when the optional minigrid package is missing.
minigrid = pytest.importorskip("minigrid")
# importorskip skips on import failure, so this should never trip; it presumably
# exists only to mark the `minigrid` name as used for linters -- TODO confirm.
assert minigrid is not None

from MiniGridEnv.env.config import EnvConfig
from MiniGridEnv.env.minigrid_env import MiniGridEnvironment
from MiniGridEnv.env.models import MiniGridAction, MiniGridObservation


def test_reset_returns_valid_observation():
    """A fresh reset yields a well-formed, non-terminal initial observation."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    first_obs = environment.reset(seed=123)

    assert isinstance(first_obs, MiniGridObservation)
    # Nothing has happened yet: step counter at zero, no history, episode open.
    assert first_obs.step_idx == 0
    assert first_obs.history == []
    assert first_obs.done is False
    # The rendered text must be a string holding something besides whitespace.
    assert isinstance(first_obs.text, str)
    assert len(first_obs.text.strip()) > 0


def test_observation_text_is_natural_language():
    """The initial observation text carries the expected prose markers."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    rendered = environment.reset(seed=123).text
    for marker in ("Mission:", "You are facing"):
        assert marker in rendered


def test_mission_is_populated():
    """Reset exposes a non-blank mission string on the observation."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    mission = environment.reset(seed=123).mission
    assert isinstance(mission, str)
    assert len(mission.strip()) > 0


def test_step_accepts_valid_command():
    """Stepping once with "go forward" returns an observation at step index 1."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)

    forward = MiniGridAction(command="go forward")
    result = environment.step(forward)

    assert isinstance(result, MiniGridObservation)
    assert result.step_idx == 1


def test_step_handles_invalid_command():
    """An unrecognised command still yields an observation and is counted invalid."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)

    bogus = MiniGridAction(command="fly away")
    result = environment.step(bogus)

    assert isinstance(result, MiniGridObservation)
    # The contract pinned here: "go forward" is reported as the last action
    # after a command that is not one of the known ones.
    assert result.last_action == "go forward"
    assert environment.state.invalid_actions == 1


def test_episode_terminates_on_success():
    """Drive an episode with a scripted bot and expect a rewarded, completed finish.

    The test is skipped when the bot class cannot be imported.
    """
    try:
        from minigrid.envs.babyai import BotAgent  # type: ignore
    except ImportError:
        # Only a missing bot justifies a skip; any other failure while importing
        # should surface as an error instead of being hidden behind a skip.
        # NOTE(review): confirm this import path exists in the installed minigrid
        # release -- if it does not, this test is skipped on every run.
        pytest.skip("BotAgent unavailable for success-path test")

    # Integer action ids (as returned by the bot) mapped to the text commands
    # that MiniGridAction takes.
    int_to_text = {
        0: "turn left",
        1: "turn right",
        2: "go forward",
        3: "pickup",
        4: "drop",
        5: "toggle",
        6: "done",
    }

    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    obs = env.reset(seed=123)
    # The bot is built on the unwrapped gym env held in a private attribute.
    bot = BotAgent(env._gym_env.unwrapped)  # type: ignore[attr-defined]

    # Stop on termination, or once the step budget reported by the observation
    # is used up, so a misbehaving bot cannot loop forever.
    while not obs.done and obs.step_idx < obs.max_steps:
        raw_obs = env._last_obs  # type: ignore[attr-defined]
        action = bot.act(raw_obs)
        # Any id outside the table falls back to the "done" command.
        obs = env.step(MiniGridAction(command=int_to_text.get(int(action), "done")))

    assert obs.done is True
    assert env.state.completed is True
    assert env.state.total_reward > 0.0


def test_episode_truncates_on_max_steps():
    """With max_steps_override=1, the first step ends the episode as truncated."""
    one_step_config = EnvConfig(level_name="GoToRedBall", max_steps_override=1)
    environment = MiniGridEnvironment(config=one_step_config)
    environment.reset(seed=123)

    final_obs = environment.step(MiniGridAction(command="go forward"))

    assert final_obs.done is True
    assert environment.state.truncated is True


def test_deterministic_with_seed():
    """Two envs reset with the same seed stay in lockstep under the same commands."""
    first = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    second = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    first_obs = first.reset(seed=777)
    second_obs = second.reset(seed=777)
    assert first_obs.text == second_obs.text

    for command in ("turn left", "turn right", "go forward", "go forward"):
        # Once either episode has ended, both must have ended; stop stepping.
        if first_obs.done or second_obs.done:
            assert first_obs.done == second_obs.done
            break
        first_obs = first.step(MiniGridAction(command=command))
        second_obs = second.step(MiniGridAction(command=command))
        assert first_obs.text == second_obs.text
        assert first_obs.done == second_obs.done


def test_state_tracks_metrics():
    """After one "go forward" step the state counters each read exactly one."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)
    environment.step(MiniGridAction(command="go forward"))

    metrics = environment.state
    assert metrics.steps_taken == 1
    assert metrics.valid_actions == 1
    # A missing key is treated as a count of zero, which would fail the check.
    assert metrics.action_distribution.get("go forward", 0) == 1


def test_all_seven_actions_accepted():
    """Step once with each of the seven text commands; none may be rejected.

    The environment is reset with a distinct seed before every command.
    """
    commands = [
        "turn left",
        "turn right",
        "go forward",
        "pickup",
        "drop",
        "toggle",
        "done",
    ]
    env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    for idx, command in enumerate(commands):
        env.reset(seed=100 + idx)
        obs = env.step(MiniGridAction(command=command))
        assert isinstance(obs, MiniGridObservation)
        # An unrecognised command also returns an observation, so the type check
        # alone cannot tell "accepted" from "rejected"; the invalid counter can.
        assert env.state.invalid_actions == 0, f"{command!r} was counted as invalid"


def test_history_grows_each_step():
    """The observation history gains one entry per step taken."""
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    environment.reset(seed=123)

    for expected_len, command in enumerate(("go forward", "turn left"), start=1):
        result = environment.step(MiniGridAction(command=command))
        assert len(result.history) == expected_len


def test_different_levels_produce_different_missions():
    """GoToRedBall and GoToObj, reset with the same seed, report different missions."""
    ball_env = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))
    obj_env = MiniGridEnvironment(config=EnvConfig(level_name="GoToObj"))

    ball_obs = ball_env.reset(seed=123)
    obj_obs = obj_env.reset(seed=123)

    assert ball_obs.mission != obj_obs.mission


def test_reset_level_kwarg_switches_level_on_same_env():
    """Passing level_name to reset switches the level on an existing env.

    Remote PT clients can pass level/level_name in reset payload.
    """
    environment = MiniGridEnvironment(config=EnvConfig(level_name="GoToRedBall"))

    before = environment.reset(seed=123)
    assert before.level_name == "GoToRedBall"

    after = environment.reset(seed=123, level_name="GoToObj")
    assert after.level_name == "GoToObj"
    assert after.mission != before.mission