"""Convert MiniGrid/BabyAI observations into rich natural language text."""

from __future__ import annotations

from typing import Any

import numpy as np

# Lookup tables for the three integer channels read from every grid cell:
# channel 0 -> OBJECT_TYPES, channel 1 -> COLORS, channel 2 -> DOOR_STATES
# (the state channel is only interpreted for doors).
OBJECT_TYPES = {
    0: "unseen",
    1: "empty",
    2: "wall",
    3: "floor",
    4: "door",
    5: "key",
    6: "ball",
    7: "box",
    8: "goal",
    9: "lava",
    10: "agent",
}

COLORS = {
    0: "red",
    1: "green",
    2: "blue",
    3: "purple",
    4: "yellow",
    5: "grey",
}

DOOR_STATES = {0: "open", 1: "closed", 2: "locked"}

DIRECTION_NAMES = {0: "east", 1: "south", 2: "west", 3: "north"}

# Cell that every description treats as the agent's own position; "ahead" is
# towards row 0 and "left" is towards column 0.
# NOTE(review): presumably matches a 7x7 egocentric view indexed as
# grid[row, col] -- confirm the axis order against whatever produces
# obs["image"].
_AGENT_ROW = 6
_AGENT_COL = 3


def _with_article(phrase: str) -> str:
    """Prefix *phrase* with "a", or "an" when it starts with a vowel letter."""
    article = "an" if phrase and phrase[0].lower() in "aeiou" else "a"
    return f"{article} {phrase}"


def _count_steps(count: int) -> str:
    """Return "<count> step" or "<count> steps" depending on *count*."""
    return f"{count} step{'s' if count != 1 else ''}"


def _format_object_name(obj_type: str, color: str | None, state: str | None = None) -> str:
    """Build a noun phrase such as "a red key" or "an open blue door".

    Args:
        obj_type: Object type name (e.g. "key", "door").
        color: Colour name, or None when the colour is unknown.
        state: Door state name; only used when ``obj_type`` is "door".

    Returns:
        The phrase with an indefinite article. Missing colour or state words
        are left out rather than rendered as "None" or duplicated.
    """
    if obj_type == "door":
        words = [word for word in (state, color, "door") if word]
    else:
        words = [word for word in (color, obj_type) if word]
    return _with_article(" ".join(words))


def _relative_position_phrase(rel_row: int, rel_col: int) -> str:
    """Describe an offset from the agent in words.

    Args:
        rel_row: Row offset; negative means ahead, positive means behind.
        rel_col: Column offset; negative means left, positive means right.

    Returns:
        E.g. "2 steps ahead and 1 to your right", or "at your position" when
        both offsets are zero.
    """
    parts: list[str] = []

    if rel_row < 0:
        parts.append(f"{_count_steps(-rel_row)} ahead")
    elif rel_row > 0:
        parts.append(f"{_count_steps(rel_row)} behind")

    if rel_col < 0:
        parts.append(f"{-rel_col} to your left")
    elif rel_col > 0:
        parts.append(f"{rel_col} to your right")

    if not parts:
        return "at your position"
    return " and ".join(parts)


def _describe_cell(grid: np.ndarray, row: int, col: int) -> str:
    """Return a short description of the cell at ``grid[row, col]``.

    Coordinates outside the grid are reported as "a wall boundary". A door
    whose state index is not in ``DOOR_STATES`` is described as closed.
    """
    in_bounds = 0 <= row < grid.shape[0] and 0 <= col < grid.shape[1]
    if not in_bounds:
        return "a wall boundary"

    obj_type = OBJECT_TYPES.get(int(grid[row, col, 0]), "unknown")
    color = COLORS.get(int(grid[row, col, 1]))

    fixed_descriptions = {
        "empty": "empty space",
        "floor": "empty space",
        "unseen": "unseen area",
        "wall": "a wall",
        "lava": "lava",
    }
    if obj_type in fixed_descriptions:
        return fixed_descriptions[obj_type]

    if obj_type == "door":
        state = DOOR_STATES.get(int(grid[row, col, 2]), "closed")
        return _format_object_name("door", color, state)
    return _format_object_name(obj_type, color)


def _scan_objects(grid: np.ndarray) -> list[dict[str, Any]]:
    """Extract notable interactive objects with relative positions.

    Cells of type unseen, empty, wall, floor and agent are skipped. Each
    remaining cell yields a dict with keys ``type``, ``color``, ``state``
    (doors only, otherwise None), ``row``, ``col``, ``rel_row``, ``rel_col``,
    ``distance`` (Manhattan distance from the agent cell) and
    ``direction_desc``.

    Returns:
        The dicts sorted by distance, then row, then column.
    """
    ignored_types = {"unseen", "empty", "wall", "floor", "agent"}
    found: list[dict[str, Any]] = []

    for row in range(grid.shape[0]):
        for col in range(grid.shape[1]):
            obj_type = OBJECT_TYPES.get(int(grid[row, col, 0]), "unknown")
            if obj_type in ignored_types:
                continue

            rel_row = row - _AGENT_ROW
            rel_col = col - _AGENT_COL
            # Only doors carry a meaningful state; an unrecognised door state
            # stays None and is omitted from the rendered name.
            state = (
                DOOR_STATES.get(int(grid[row, col, 2]))
                if obj_type == "door"
                else None
            )
            found.append(
                {
                    "type": obj_type,
                    "color": COLORS.get(int(grid[row, col, 1])),
                    "state": state,
                    "row": row,
                    "col": col,
                    "rel_row": rel_row,
                    "rel_col": rel_col,
                    "distance": abs(rel_row) + abs(rel_col),
                    "direction_desc": _relative_position_phrase(rel_row, rel_col),
                }
            )

    found.sort(key=lambda item: (item["distance"], item["row"], item["col"]))
    return found


def _describe_immediate_surroundings(grid: np.ndarray) -> str:
    """Describe the cells directly ahead of, left of and right of the agent."""
    ahead = _describe_cell(grid, _AGENT_ROW - 1, _AGENT_COL)
    left = _describe_cell(grid, _AGENT_ROW, _AGENT_COL - 1)
    right = _describe_cell(grid, _AGENT_ROW, _AGENT_COL + 1)
    return (
        f"Directly ahead: {ahead}.\n"
        f"To your left: {left}.\n"
        f"To your right: {right}."
    )


def _describe_path_ahead(grid: np.ndarray) -> str:
    """Describe what appears in the straight-ahead lane.

    Walks from the cell in front of the agent towards row 0, collapsing runs
    of empty cells into "empty space for N step(s)" and stopping at the first
    wall, boundary or unseen cell.
    """
    sight_blockers = {"a wall", "a wall boundary", "unseen area"}
    segments: list[str] = []
    empty_run = 0

    for row in range(_AGENT_ROW - 1, -1, -1):
        cell_desc = _describe_cell(grid, row, _AGENT_COL)
        if cell_desc == "empty space":
            empty_run += 1
            continue

        # A non-empty cell ends the current run of empty cells.
        if empty_run:
            segments.append(f"empty space for {_count_steps(empty_run)}")
            empty_run = 0
        segments.append(cell_desc)
        if cell_desc in sight_blockers:
            break

    if empty_run:
        segments.append(f"empty space for {_count_steps(empty_run)}")

    if not segments:
        return "Looking ahead: no clear information."
    return f"Looking ahead: {', then '.join(segments)}."


def _describe_notable_objects(objects: list[dict[str, Any]]) -> str:
    """List visible interactive objects with positions.

    Args:
        objects: Dicts as produced by ``_scan_objects``; the keys ``type`` and
            ``direction_desc`` are required, ``color`` and ``state`` optional.
    """
    if not objects:
        return "Notable objects: none visible."

    lines = ["Notable objects:"]
    for obj in objects:
        name = _format_object_name(obj["type"], obj.get("color"), obj.get("state"))
        lines.append(f"- {name} ({obj['direction_desc']}).")
    return "\n".join(lines)


def _describe_carrying_status(carrying: Any) -> str:
    """Describe what the agent is currently carrying.

    Args:
        carrying: None, a dict with optional ``type``/``color`` keys, or any
            object exposing optional ``type``/``color`` attributes.
    """
    if carrying is None:
        return "You are carrying: nothing."

    if isinstance(carrying, dict):
        obj_type = carrying.get("type")
        color = carrying.get("color")
    else:
        obj_type = getattr(carrying, "type", None)
        color = getattr(carrying, "color", None)

    if obj_type is None:
        return "You are carrying: an object."
    if color:
        return f"You are carrying: a {color} {obj_type}."
    return f"You are carrying: a {obj_type}."
def _render_ascii_grid(grid: np.ndarray) -> str:
    """Render a compact ASCII view for debugging.

    Each cell becomes one character chosen from its object-type channel
    (index 0); object types without a glyph are drawn as "!". Rows are joined
    with newlines.
    """
    glyphs = {
        "unseen": "?",
        "empty": ".",
        "wall": "#",
        "floor": ".",
        "door": "D",
        "key": "K",
        "ball": "B",
        "box": "X",
        "goal": "G",
        "lava": "L",
        "agent": "A",
    }
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    return "\n".join(
        "".join(
            glyphs.get(OBJECT_TYPES.get(int(grid[row, col, 0]), "unknown"), "!")
            for col in range(n_cols)
        )
        for row in range(n_rows)
    )


def grid_to_text(
    obs: dict[str, Any], carrying: Any = None, include_raw_grid: bool = False
) -> str:
    """Convert MiniGrid raw observation dict to a rich language description.

    Args:
        obs: Observation mapping. ``image`` holds the grid (anything
            ``np.asarray`` accepts), ``mission`` the task text and
            ``direction`` an index into ``DIRECTION_NAMES``.
        carrying: Object the agent holds, forwarded to
            ``_describe_carrying_status``.
        include_raw_grid: When true, append an ASCII rendering of the grid.

    Returns:
        A multi-line description. When ``image`` is missing, a two-line
        notice is returned instead.
    """
    grid = obs.get("image")
    if grid is None:
        return "Mission: unknown.\nObservation is missing grid image."

    # A present-but-None mission must not be rendered as the string "None".
    raw_mission = obs.get("mission")
    mission = "" if raw_mission is None else str(raw_mission).strip()
    if not mission:
        mission = "unknown mission"

    # An absent direction keeps the historical default of 0; a None or
    # non-numeric value is reported as "unknown" instead of raising.
    try:
        direction_name = DIRECTION_NAMES.get(int(obs.get("direction", 0)), "unknown")
    except (TypeError, ValueError):
        direction_name = "unknown"

    if not isinstance(grid, np.ndarray):
        grid = np.asarray(grid)

    parts = [
        f"Mission: {mission}",
        f"You are facing {direction_name}.",
        "",
        _describe_immediate_surroundings(grid),
        _describe_path_ahead(grid),
        _describe_notable_objects(_scan_objects(grid)),
        _describe_carrying_status(carrying),
    ]
    if include_raw_grid:
        parts.extend(["", "Raw grid (debug):", _render_ascii_grid(grid)])
    return "\n".join(parts)