# MiniGridEnv / env / grid_to_text.py
# Uploaded by yashu2000 ("Upload folder using huggingface_hub"), commit a03a89b (verified).
"""Convert MiniGrid/BabyAI observations into rich natural language text."""
from __future__ import annotations
from typing import Any
import numpy as np
# Lookup tables for the per-cell encoding read below: channel 0 of a grid cell
# indexes OBJECT_TYPES, channel 1 indexes COLORS and channel 2 indexes
# DOOR_STATES (the state channel is only interpreted for doors in this module).
OBJECT_TYPES = dict(
    enumerate(
        (
            "unseen",
            "empty",
            "wall",
            "floor",
            "door",
            "key",
            "ball",
            "box",
            "goal",
            "lava",
            "agent",
        )
    )
)
COLORS = dict(enumerate(("red", "green", "blue", "purple", "yellow", "grey")))
DOOR_STATES = dict(enumerate(("open", "closed", "locked")))
# Maps the observation's integer "direction" to a compass heading.
DIRECTION_NAMES = dict(enumerate(("east", "south", "west", "north")))

# Agent cell inside the egocentric view, read as grid[row, col]: the last row
# and the middle column. Presumably this assumes the default 7x7 view — TODO
# confirm for other view sizes.
# NOTE(review): raw MiniGrid ``obs["image"]`` is laid out as [x, y]
# (shape (width, height, 3)), which would place the agent at image[3, 6];
# these values assume callers pass a transposed [row, col] grid — verify.
_AGENT_ROW = 6
_AGENT_COL = 3
def _format_object_name(obj_type: str, color: str | None, state: str | None = None) -> str:
    """Build an indefinite noun phrase such as "a red key" or "a locked blue door".

    Args:
        obj_type: Object type name, e.g. ``"key"`` or ``"door"``.
        color: Color name, or ``None``/empty when unknown.
        state: Door state name (e.g. ``"open"``, ``"closed"``, ``"locked"``).
            Only used when ``obj_type`` is ``"door"``. A missing state is
            omitted rather than rendered as the literal text "None".

    Returns:
        The phrase prefixed with "an" when it starts with a vowel, otherwise "a".
    """
    if obj_type == "door":
        # Doors read "<state> <color> door"; drop whichever qualifier is missing
        # (avoids "a None red door" and "a door door").
        words = [word for word in (state, color) if word]
        words.append("door")
    elif color:
        words = [color, obj_type]
    else:
        words = [obj_type]
    phrase = " ".join(str(word) for word in words)
    # Pick the article from the first letter so "open"/"unknown" read correctly.
    article = "an" if phrase[:1].lower() in {"a", "e", "i", "o", "u"} else "a"
    return f"{article} {phrase}"
def _relative_position_phrase(rel_row: int, rel_col: int) -> str:
    """Describe a grid offset from the agent in first-person terms.

    Args:
        rel_row: Row offset from the agent; negative means ahead, positive behind.
        rel_col: Column offset; negative means left, positive right.

    Returns:
        A phrase like "2 steps ahead and 1 step to your left", or
        "at your position" when both offsets are zero. Every component carries
        a "step"/"steps" unit so forward and lateral distances read alike.
    """

    def _steps(count: int) -> str:
        return f"{count} step{'' if count == 1 else 's'}"

    parts: list[str] = []
    if rel_row < 0:
        parts.append(f"{_steps(-rel_row)} ahead")
    elif rel_row > 0:
        parts.append(f"{_steps(rel_row)} behind")
    if rel_col < 0:
        parts.append(f"{_steps(-rel_col)} to your left")
    elif rel_col > 0:
        parts.append(f"{_steps(rel_col)} to your right")
    if not parts:
        return "at your position"
    return " and ".join(parts)
def _describe_cell(grid: np.ndarray, row: int, col: int) -> str:
    """Return a short phrase naming the contents of ``grid[row, col]``.

    Coordinates outside the grid read as "a wall boundary". Empty/floor,
    unseen, wall and lava cells get fixed phrases. Doors include their state,
    and an unrecognised state code reads as "closed". Any other object is named
    through ``_format_object_name`` with its color.
    """
    if not (0 <= row < grid.shape[0] and 0 <= col < grid.shape[1]):
        return "a wall boundary"

    kind = OBJECT_TYPES.get(int(grid[row, col, 0]), "unknown")
    tint = COLORS.get(int(grid[row, col, 1]))
    state_code = int(grid[row, col, 2])

    fixed_phrases = {
        "empty": "empty space",
        "floor": "empty space",
        "unseen": "unseen area",
        "wall": "a wall",
        "lava": "lava",
    }
    if kind in fixed_phrases:
        return fixed_phrases[kind]
    if kind == "door":
        return _format_object_name("door", tint, DOOR_STATES.get(state_code, "closed"))
    return _format_object_name(kind, tint)
def _scan_objects(grid: np.ndarray) -> list[dict[str, Any]]:
    """Collect every interactive object in ``grid`` with its offset from the agent.

    Unseen, empty, wall, floor and agent cells are skipped. Each record holds
    ``type``, ``color`` and ``state`` (state only for doors), the absolute
    ``row``/``col``, the offsets ``rel_row``/``rel_col`` from the agent cell,
    the Manhattan ``distance`` and a readable ``direction_desc``. Records are
    returned nearest first, with ties broken by row and then column.
    """
    background = {"unseen", "empty", "wall", "floor", "agent"}
    found: list[dict[str, Any]] = []
    for row, col in np.ndindex(grid.shape[0], grid.shape[1]):
        type_code = int(grid[row, col, 0])
        color_code = int(grid[row, col, 1])
        state_code = int(grid[row, col, 2])
        kind = OBJECT_TYPES.get(type_code, "unknown")
        if kind in background:
            continue
        d_row = row - _AGENT_ROW
        d_col = col - _AGENT_COL
        found.append(
            {
                "type": kind,
                "color": COLORS.get(color_code),
                "state": DOOR_STATES.get(state_code) if kind == "door" else None,
                "row": row,
                "col": col,
                "rel_row": d_row,
                "rel_col": d_col,
                "distance": abs(d_row) + abs(d_col),
                "direction_desc": _relative_position_phrase(d_row, d_col),
            }
        )
    return sorted(found, key=lambda rec: (rec["distance"], rec["row"], rec["col"]))
def _describe_immediate_surroundings(grid: np.ndarray) -> str:
    """Describe the three cells adjacent to the agent, one labelled line each.

    "Directly ahead" is the cell one row above the agent, and "To your left" /
    "To your right" are the cells one column to either side.
    """
    neighbours = (
        ("Directly ahead", _AGENT_ROW - 1, _AGENT_COL),
        ("To your left", _AGENT_ROW, _AGENT_COL - 1),
        ("To your right", _AGENT_ROW, _AGENT_COL + 1),
    )
    return "\n".join(
        f"{label}: {_describe_cell(grid, r, c)}." for label, r, c in neighbours
    )
def _describe_path_ahead(grid: np.ndarray) -> str:
    """Summarise the agent's column from just ahead of it up to row 0.

    Consecutive empty cells are collapsed into "empty space for N step(s)".
    The scan stops after the first wall, out-of-grid cell or unseen cell, and
    segments are joined with ", then ".
    """

    def _empty_run_text(count: int) -> str:
        return f"empty space for {count} step{'s' if count != 1 else ''}"

    stoppers = {"a wall", "a wall boundary", "unseen area"}
    pieces: list[str] = []
    pending_empty = 0
    row = _AGENT_ROW - 1
    while row >= 0:
        desc = _describe_cell(grid, row, _AGENT_COL)
        row -= 1
        if desc == "empty space":
            pending_empty += 1
            continue
        if pending_empty:
            pieces.append(_empty_run_text(pending_empty))
            pending_empty = 0
        pieces.append(desc)
        if desc in stoppers:
            break
    if pending_empty:
        pieces.append(_empty_run_text(pending_empty))

    if not pieces:
        return "Looking ahead: no clear information."
    # A single piece joins to itself, so one format covers both cases.
    return "Looking ahead: " + ", then ".join(pieces) + "."
def _describe_notable_objects(objects: list[dict[str, Any]]) -> str:
    """Render object records (as built by the scanner) as a bulleted list.

    Each record must provide ``type`` and ``direction_desc``. ``color`` and
    ``state`` are optional. An empty list yields a single "none visible" line.
    """
    if not objects:
        return "Notable objects: none visible."
    bullet_lines = [
        f"- {_format_object_name(item['type'], item.get('color'), item.get('state'))}"
        f" ({item['direction_desc']})."
        for item in objects
    ]
    return "\n".join(["Notable objects:", *bullet_lines])
def _describe_carrying_status(carrying: Any) -> str:
    """Describe the carried item as a "You are carrying: ..." sentence.

    ``carrying`` may be ``None``, a dict with ``"type"``/``"color"`` keys, or
    any object exposing ``type``/``color`` attributes. Missing fields count as
    ``None``, and an item without a type is reported as "an object".
    """
    if carrying is None:
        return "You are carrying: nothing."

    if isinstance(carrying, dict):
        obj_type, color = carrying.get("type"), carrying.get("color")
    else:
        obj_type = getattr(carrying, "type", None)
        color = getattr(carrying, "color", None)

    if obj_type is None:
        return "You are carrying: an object."
    described = f"{color} {obj_type}" if color else f"{obj_type}"
    return f"You are carrying: a {described}."
def _render_ascii_grid(grid: np.ndarray) -> str:
    """Draw ``grid`` with one character per cell and one text line per row.

    Intended for debugging. Object types without a glyph, including
    unrecognised codes, are drawn as "!".
    """
    symbols = {
        "unseen": "?",
        "empty": ".",
        "floor": ".",
        "wall": "#",
        "door": "D",
        "key": "K",
        "ball": "B",
        "box": "X",
        "goal": "G",
        "lava": "L",
        "agent": "A",
    }

    def _glyph(r: int, c: int) -> str:
        kind = OBJECT_TYPES.get(int(grid[r, c, 0]), "unknown")
        return symbols.get(kind, "!")

    return "\n".join(
        "".join(_glyph(r, c) for c in range(grid.shape[1]))
        for r in range(grid.shape[0])
    )
def grid_to_text(
    obs: dict[str, Any], carrying: Any = None, include_raw_grid: bool = False
) -> str:
    """Convert a MiniGrid observation dict into a multi-line text description.

    Args:
        obs: Observation dict. Reads ``"image"`` (the encoded view grid),
            ``"mission"`` (free text) and ``"direction"`` (an index into
            ``DIRECTION_NAMES``).
        carrying: The carried item: ``None``, a dict with ``"type"``/``"color"``
            keys, or an object exposing ``type``/``color`` attributes.
        include_raw_grid: When true, append an ASCII rendering of the grid.

    Returns:
        The description text. A missing or ``None`` ``"image"`` yields a short
        notice that still reports the mission. A ``None`` mission is treated as
        missing. A missing, ``None`` or non-numeric ``"direction"`` is reported
        as "unknown" instead of silently defaulting to east.

    NOTE(review): cells are read as ``grid[row, col]`` with a fixed agent cell;
    raw MiniGrid images are laid out ``[x, y]``, so confirm callers transpose.
    """
    raw_mission = obs.get("mission")
    # str(None) would otherwise produce the literal mission text "None".
    mission = "" if raw_mission is None else str(raw_mission).strip()
    mission = mission or "unknown mission"

    grid = obs.get("image")
    if grid is None:
        # Keep the mission even without a view so the prompt stays useful.
        return f"Mission: {mission}.\nObservation is missing grid image."

    raw_direction = obs.get("direction")
    try:
        direction_index = int(raw_direction)
    except (TypeError, ValueError):
        # Missing (None) or non-numeric heading: do not claim a direction.
        direction_name = "unknown"
    else:
        direction_name = DIRECTION_NAMES.get(direction_index, "unknown")

    if not isinstance(grid, np.ndarray):
        grid = np.asarray(grid)

    objects = _scan_objects(grid)
    parts = [
        f"Mission: {mission}",
        f"You are facing {direction_name}.",
        "",
        _describe_immediate_surroundings(grid),
        _describe_path_ahead(grid),
        _describe_notable_objects(objects),
        _describe_carrying_status(carrying),
    ]
    if include_raw_grid:
        parts.extend(["", "Raw grid (debug):", _render_ascii_grid(grid)])
    return "\n".join(parts)