# MiniGridEnv/env/action_parser.py
# Hosting-page residue kept for provenance:
#   yashu2000's picture / Upload folder using huggingface_hub
#   commit a03a89b (verified)
"""Parse free-form model text into MiniGrid discrete actions."""
from __future__ import annotations
import re
# Canonical action name -> discrete action index returned by `parse_action`.
# NOTE(review): the 0-6 numbering presumably mirrors MiniGrid's `Actions`
# enum -- confirm against the environment before changing any value.
CANONICAL_ACTION_TO_INDEX: dict[str, int] = {
"turn left": 0,
"turn right": 1,
"go forward": 2,
"pickup": 3,
"drop": 4,
"toggle": 5,
"done": 6,
}
# Primary phrase table: accepted phrase -> canonical action name.
# `parse_action` consults this table before `ALIASES`.
# Entry order matters: `_match_from_map` prefers the longest matching key and,
# among equally long keys, the one inserted first.
ACTION_MAP: dict[str, str] = {
"turn left": "turn left",
"turn right": "turn right",
"go forward": "go forward",
"move forward": "go forward",
"forward": "go forward",
"pickup": "pickup",
"pick up": "pickup",
"grab": "pickup",
"drop": "drop",
"toggle": "toggle",
# "open" and "close" both resolve to the single "toggle" action.
"open": "toggle",
"close": "toggle",
"done": "done",
# "wait" and "noop" are folded into "done".
"wait": "done",
"noop": "done",
}
# Secondary table of looser, mostly single-word synonyms. It is only consulted
# when nothing in `ACTION_MAP` matched; the same ordering rule applies.
ALIASES: dict[str, str] = {
"left": "turn left",
"right": "turn right",
"ahead": "go forward",
"step": "go forward",
"walk": "go forward",
"take": "pickup",
"get": "pickup",
"release": "drop",
"put down": "drop",
"unlock": "toggle",
"switch": "toggle",
"stop": "done",
}
_ACTION_PATTERN = re.compile(r"action\s*:\s*(.+)", re.IGNORECASE)
def _extract_structured_action(text: str) -> str:
"""Extract action payload from `Action: ...` format when present."""
match = _ACTION_PATTERN.search(text)
if not match:
return text
candidate = match.group(1).strip()
return candidate.splitlines()[0].strip()
def _match_from_map(cleaned: str, mapping: dict[str, str]) -> str | None:
if cleaned in mapping:
return mapping[cleaned]
best_key = None
best_len = -1
for key in mapping:
if key in cleaned and len(key) > best_len:
best_key = key
best_len = len(key)
if best_key is None:
return None
return mapping[best_key]
def parse_action(text: str) -> tuple[int, str, bool]:
    """Turn raw model output into ``(action_index, canonical_action, is_valid)``.

    The text is trimmed and lower-cased, reduced to its ``Action: ...``
    payload when one is present, and then looked up in ``ACTION_MAP`` and,
    failing that, in ``ALIASES``.

    Args:
        text: Free-form model output. A falsy value (``None`` or ``""``) or a
            whitespace-only string is treated as unparseable.

    Returns:
        The index and name of the recognised action with ``is_valid=True``.
        When nothing is recognised, the "go forward" action is returned with
        ``is_valid=False``.
    """
    normalized = (text or "").strip().lower()

    resolved = None
    if normalized:
        payload = _extract_structured_action(normalized)
        # Primary phrases take precedence over the looser aliases.
        for table in (ACTION_MAP, ALIASES):
            resolved = _match_from_map(payload, table)
            if resolved is not None:
                break

    if resolved is None:
        # Unparseable input: default action, flagged as invalid.
        return CANONICAL_ACTION_TO_INDEX["go forward"], "go forward", False
    return CANONICAL_ACTION_TO_INDEX[resolved], resolved, True