File size: 2,462 Bytes
a03a89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Parse free-form model text into MiniGrid discrete actions."""

from __future__ import annotations

import re

# Canonical action name -> discrete action index. Each name's index is its
# position in the tuple below, so the order of the tuple is significant.
# NOTE(review): presumably these indices mirror the MiniGrid action enum
# mentioned in the module docstring -- confirm against the environment.
CANONICAL_ACTION_TO_INDEX: dict[str, int] = {
    name: index
    for index, name in enumerate(
        (
            "turn left",
            "turn right",
            "go forward",
            "pickup",
            "drop",
            "toggle",
            "done",
        )
    )
}

# Primary phrase -> canonical action name. Written grouped by canonical
# action and flattened into a phrase-keyed dict. Insertion order is kept
# stable on purpose: when two phrases of equal length both occur in a text,
# the lookup helper keeps the one that was inserted first.
ACTION_MAP: dict[str, str] = {
    phrase: canonical
    for canonical, phrases in (
        ("turn left", ("turn left",)),
        ("turn right", ("turn right",)),
        ("go forward", ("go forward", "move forward", "forward")),
        ("pickup", ("pickup", "pick up", "grab")),
        ("drop", ("drop",)),
        ("toggle", ("toggle", "open", "close")),
        ("done", ("done", "wait", "noop")),
    )
    for phrase in phrases
}

# Looser, secondary phrase -> canonical action name, grouped by canonical
# action and flattened into a phrase-keyed dict. Insertion order is kept
# stable because it decides ties between equally long phrases during lookup.
ALIASES: dict[str, str] = {
    phrase: canonical
    for canonical, phrases in (
        ("turn left", ("left",)),
        ("turn right", ("right",)),
        ("go forward", ("ahead", "step", "walk")),
        ("pickup", ("take", "get")),
        ("drop", ("release", "put down")),
        ("toggle", ("unlock", "switch")),
        ("done", ("stop",)),
    )
    for phrase in phrases
}

_ACTION_PATTERN = re.compile(r"action\s*:\s*(.+)", re.IGNORECASE)


def _extract_structured_action(text: str) -> str:
    """Extract action payload from `Action: ...` format when present."""
    match = _ACTION_PATTERN.search(text)
    if not match:
        return text
    candidate = match.group(1).strip()
    return candidate.splitlines()[0].strip()


def _match_from_map(cleaned: str, mapping: dict[str, str]) -> str | None:
    if cleaned in mapping:
        return mapping[cleaned]

    best_key = None
    best_len = -1
    for key in mapping:
        if key in cleaned and len(key) > best_len:
            best_key = key
            best_len = len(key)
    if best_key is None:
        return None
    return mapping[best_key]


def parse_action(text: str) -> tuple[int, str, bool]:
    """Turn model output into `(action_index, canonical_action, is_valid)`.

    The text is stripped and lower-cased, narrowed to the payload of an
    `action:` marker when one is present, and then looked up in `ACTION_MAP`
    and, failing that, in `ALIASES`.

    Args:
        text: Raw model output. Falsy values (including `None`) are treated
            as an empty string.

    Returns:
        On a match, the canonical action's index, its name and `True`.
        Otherwise the "go forward" index and name with `is_valid` set to
        `False`.
    """
    normalized = (text or "").strip().lower()
    if normalized:
        payload = _extract_structured_action(normalized)
        # Primary phrases take precedence over the looser aliases.
        for table in (ACTION_MAP, ALIASES):
            action = _match_from_map(payload, table)
            if action is not None:
                return CANONICAL_ACTION_TO_INDEX[action], action, True

    # Empty or unrecognised text: fall back to moving forward, flagged invalid.
    default_action = "go forward"
    return CANONICAL_ACTION_TO_INDEX[default_action], default_action, False