File size: 896 Bytes
a03a89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""BabyAI built-in bot baseline (near-optimal upper bound)."""

from __future__ import annotations

try:
    from ..env.models import MiniGridAction, MiniGridObservation
except ImportError:
    from env.models import MiniGridAction, MiniGridObservation

# Text command for each discrete action index. The position of a command in
# the tuple below is its integer key, so the keys run from 0 ("turn left")
# through 6 ("done").
# NOTE(review): the ordering is presumably meant to mirror MiniGrid's action
# enumeration -- confirm against the environment before reordering.
INT_TO_TEXT = dict(
    enumerate(
        (
            "turn left",
            "turn right",
            "go forward",
            "pickup",
            "drop",
            "toggle",
            "done",
        )
    )
)


class BabyAIBotBaseline:
    """Adapter exposing BabyAI's symbolic planner bot as a baseline policy.

    The wrapped bot is built from the unwrapped gym environment and queried
    once per ``select_action`` call; the integer action it returns is
    translated to a text command through ``INT_TO_TEXT``.
    """

    def __init__(self, gym_env):
        """Construct the planner bot for ``gym_env``.

        Args:
            gym_env: Environment whose ``unwrapped`` attribute is handed to
                the bot constructor.

        Raises:
            ImportError: If no bot implementation can be imported.
        """
        # The import is deferred to construction time, as in the original.
        try:
            from minigrid.envs.babyai import BotAgent as bot_cls  # type: ignore
        except ImportError:
            # NOTE(review): published Minigrid releases appear to ship the
            # planner as ``minigrid.utils.baby_ai_bot.BabyAIBot`` rather than
            # ``minigrid.envs.babyai.BotAgent`` -- confirm against the
            # installed version. Falling back keeps the adapter usable there
            # instead of failing on construction.
            from minigrid.utils.baby_ai_bot import (  # type: ignore
                BabyAIBot as bot_cls,
            )

        self._bot = bot_cls(gym_env.unwrapped)

    def select_action(
        self, obs: MiniGridObservation, raw_obs: dict
    ) -> MiniGridAction:
        """Ask the bot for its next action and wrap it as a text command.

        Args:
            obs: Unused; accepted so the signature matches the caller's
                policy interface.
            raw_obs: Raw observation forwarded to the bot's ``act`` method
                when the bot provides one.

        Returns:
            A ``MiniGridAction`` whose ``command`` is the ``INT_TO_TEXT``
            entry for the bot's action, or ``"done"`` when the action index
            is not in the table.
        """
        del obs
        planner = self._bot
        if hasattr(planner, "act"):
            action_int = planner.act(raw_obs)
        else:
            # ``replan``-style bots take no observation.
            # NOTE(review): calling ``replan()`` without arguments presumably
            # makes the bot assume its previously suggested action was the one
            # executed -- verify the caller always steps the env with the
            # returned command.
            action_int = planner.replan()
        command = INT_TO_TEXT.get(int(action_int), "done")
        return MiniGridAction(command=command)