Wildfire-Containment-Simulator / env /action_parser.py
Eshit's picture
Deploy to HF Space
363abf3
"""
Robust LLM output β†’ Action parser with 3-layer fallback.
Layer 1: Direct JSON parse
Layer 2: Regex field extraction
Layer 3: Safe IDLE fallback
"""
from __future__ import annotations
import json
import re
from typing import TYPE_CHECKING, Tuple
from .models import Action, ActionType, Direction
if TYPE_CHECKING:
from .models import Observation
_SAFE_IDLE = Action(action_type=ActionType.IDLE, reason="parse_failure")
_ACTION_TYPES = {a.value for a in ActionType}
_DIRECTIONS = {d.value for d in Direction}
def parse_action(llm_output: str, obs: "Observation") -> Tuple[Action, str]:
"""
Convert raw LLM text into a validated Action.
Returns (action, status) where status is one of:
"json_success", "regex_fallback", "safe_idle"
"""
grid_rows = len(obs.grid)
grid_cols = len(obs.grid[0]) if grid_rows > 0 else 0
# Layer 1 β€” direct JSON
action, status = _try_json(llm_output)
if action is not None:
action = _bounds_check(action, grid_rows, grid_cols)
return action, status
# Layer 2 β€” regex
action, status = _try_regex(llm_output)
if action is not None:
action = _bounds_check(action, grid_rows, grid_cols)
return action, status
# Layer 3 β€” safe fallback
return _SAFE_IDLE, "safe_idle"
# ── Layer 1 ──────────────────────────────────────────────────
def _try_json(text: str) -> Tuple[Action | None, str]:
raw = _extract_json_block(text)
if raw is None:
return None, "safe_idle"
try:
data = json.loads(raw)
if not isinstance(data, dict):
return None, "safe_idle"
# Normalise action_type casing
if "action_type" in data:
data["action_type"] = str(data["action_type"]).lower()
if data.get("action_type") not in _ACTION_TYPES:
return None, "safe_idle"
action = Action(**data)
return action, "json_success"
except Exception:
return None, "safe_idle"
def _extract_json_block(text: str) -> str | None:
"""Find first balanced {...} block, stripping ```json fences."""
# Strip code fences
text = re.sub(r"```(?:json)?\s*", "", text)
text = text.replace("```", "")
start = text.find("{")
if start == -1:
return None
depth = 0
for i, ch in enumerate(text[start:], start=start):
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
return None
# ── Layer 2 ──────────────────────────────────────────────────
def _try_regex(text: str) -> Tuple[Action | None, str]:
# action_type
at_match = re.search(
r'action_type["\s:]+["\']?(' + "|".join(_ACTION_TYPES) + r")[\"']?",
text,
re.IGNORECASE,
)
if not at_match:
return None, "safe_idle"
action_type = at_match.group(1).lower()
def _str(pattern: str) -> str | None:
m = re.search(pattern, text, re.IGNORECASE)
return m.group(1) if m else None
def _int(pattern: str) -> int | None:
m = re.search(pattern, text, re.IGNORECASE)
return int(m.group(1)) if m else None
crew_id = _str(r'crew_id["\s:]+["\']?(crew_\d+)["\']?')
tanker_id = _str(r'tanker_id["\s:]+["\']?(tanker_\d+)["\']?')
target_row = _int(r'target_row["\s:]+(\d+)')
target_col = _int(r'target_col["\s:]+(\d+)')
direction_raw = _str(
r'direction["\s:]+["\']?(' + "|".join(_DIRECTIONS) + r")[\"']?"
)
direction = direction_raw.upper() if direction_raw else None
try:
action = Action(
action_type=action_type,
crew_id=crew_id,
tanker_id=tanker_id,
target_row=target_row,
target_col=target_col,
direction=direction,
)
return action, "regex_fallback"
except Exception:
return None, "safe_idle"
# ── Bounds check ─────────────────────────────────────────────
def _bounds_check(action: Action, grid_rows: int, grid_cols: int) -> Action:
"""Downgrade to IDLE if target coords are outside the grid."""
row, col = action.target_row, action.target_col
if row is None and col is None:
return action
if row is None or col is None:
return _SAFE_IDLE
if not (0 <= row < grid_rows and 0 <= col < grid_cols):
return _SAFE_IDLE
return action