File size: 8,852 Bytes
1c8c60e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import asyncio
import os
import re
import shutil
import sys
from typing import Callable, Dict, List, Optional

from mllm.markov_games.rollout_tree import ChatTurn

try:
    import rstr  # For generating example strings from regex
except Exception:  # pragma: no cover
    rstr = None


def _clear_terminal() -> None:
    """
    Clear the terminal screen in a cross-platform manner.
    """
    if sys.stdout.isatty():
        os.system("cls" if os.name == "nt" else "clear")


def _terminal_width(default: int = 100) -> int:
    try:
        return shutil.get_terminal_size().columns
    except Exception:
        return default


def _horizontal_rule(char: str = "─") -> str:
    """
    Build a divider line by repeating ``char`` across the terminal.

    The repeat count is the terminal width minus two, but never fewer
    than 20.
    """
    repeat_count = max(_terminal_width() - 2, 20)
    return repeat_count * char


class _Style:
    # ANSI colors (bright, readable)
    RESET = "\033[0m"
    BOLD = "\033[1m"
    DIM = "\033[2m"
    # Foreground colors
    FG_BLUE = "\033[94m"  # user/system headers
    FG_GREEN = "\033[92m"  # human response header
    FG_YELLOW = "\033[93m"  # notices
    FG_RED = "\033[91m"  # errors
    FG_MAGENTA = "\033[95m"  # regex
    FG_CYAN = "\033[96m"  # tips


def _render_chat(state) -> str:
    """
    Format the conversation history as one printable block of text.

    Each item yielded by ``state`` is read through its ``role`` and
    ``content`` attributes. The output is a titled banner, then one header
    line per turn followed by that turn's content indented by two spaces,
    and finally a closing divider.
    """
    divider = _horizontal_rule()
    title = f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}"
    rendered: List[str] = [divider, title, divider]

    # Roles that get a dedicated label; the "assistant" role is shown as HUMAN.
    known_headers = {
        "assistant": f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑‍💻{_Style.RESET}",
        "user": f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}",
    }

    for turn in state:
        speaker = turn.role
        body = str(turn.content).strip()
        label = known_headers.get(speaker)
        if label is None:
            # Any other role is shown dimmed, upper-cased, inside brackets.
            label = f"[{_Style.DIM}{speaker.upper()}{_Style.RESET}]"
        rendered.append(label)
        # An empty body still contributes a single indented blank line.
        rendered.extend(f"  {text_line}" for text_line in body.splitlines() or [""])
        rendered.append("")

    rendered.append(divider)
    return "\n".join(rendered)


async def _async_input(prompt_text: str) -> str:
    """Non-blocking input using a background thread."""
    return await asyncio.to_thread(input, prompt_text)


def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]:
    """
    Generate a sample string for ``regex`` that is at most ``max_len`` long.

    Up to 20 random candidates are drawn from ``rstr.xeger`` and the first
    one that is short enough is returned. ``None`` is returned when ``rstr``
    is unavailable, when no candidate fits, or when generation raises.
    """
    if rstr is None:
        return None
    try:
        # Lazy generator: candidates are produced one at a time, so drawing
        # stops as soon as a short enough one is found.
        candidates = (rstr.xeger(regex) for _ in range(20))
        # Over-long candidates are never truncated, since that could break
        # the match; give up with None instead.
        return next((text for text in candidates if len(text) <= max_len), None)
    except Exception:
        return None


def _detect_input_type(regex: str | None) -> tuple[str, str, str]:
    """
    Detect what type of input is expected based on the regex pattern.
    Returns (input_type, start_tag, end_tag)
    """
    if regex is None:
        return "text", "", ""

    if "message_start" in regex and "message_end" in regex:
        return "message", "<<message_start>>", "<<message_end>>"
    elif "proposal_start" in regex and "proposal_end" in regex:
        return "proposal", "<<proposal_start>>", "<<proposal_end>>"
    else:
        return "text", "", ""


def _print_command_help() -> None:
    """Print the list of slash commands understood at the input prompt."""
    print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}")
    print(
        f"  {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET}     Show this help"
    )
    print(
        f"  {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET}  Re-render the conversation and prompt"
    )
    print(
        f"  {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET}     Abort the run (raises KeyboardInterrupt)"
    )


def _print_input_instructions(regex: "str | None", input_type: str) -> None:
    """Print the expected format (if any) and what the user should type."""
    if not regex:
        print(
            f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}"
        )
        return

    example = _short_regex_example(regex, max_len=30)
    print(
        f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}"
    )
    print(f"  {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
    if example:
        print(
            f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}"
        )
    print(_horizontal_rule("."))

    # Custom prompt based on input type
    if input_type == "message":
        print(
            f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}"
        )
    elif input_type == "proposal":
        print(
            f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}"
        )
    else:
        print(f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}")

    print(
        f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}"
    )


def _print_validation_error(regex: str, input_type: str, formatted_input: str) -> None:
    """Explain why the last input was rejected and how to correct it."""
    print("")
    print(
        f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again."
    )

    if input_type == "message":
        print(f"You entered: {_Style.FG_CYAN}{formatted_input}{_Style.RESET}")
        print("Just type the message content without tags.")
    elif input_type == "proposal":
        print(f"You entered: {_Style.FG_CYAN}{formatted_input}{_Style.RESET}")
        print("Just type the number without tags.")
    else:
        print("Expected (regex):")
        print(f"  {_Style.FG_MAGENTA}{regex}{_Style.RESET}")


async def human_policy(state, agent_id, regex: str | None = None) -> "ChatTurn":
    """
    Async human-in-the-loop policy.

    - Displays prior conversation context in the terminal.
    - Prompts the user for a response.
    - If a regex is provided, validates and re-prompts until it matches.
    - Automatically adds formatting tags based on expected input type.

    Args:
        state: Conversation history; passed to ``_render_chat`` for display.
        agent_id: Identifier stored on the returned ``ChatTurn``.
        regex: Optional fullmatch validation pattern. ``None`` or an empty
            string means the input is accepted without validation. A pattern
            that fails to compile triggers a warning and the input is
            accepted as-is.

    Returns:
        A ``ChatTurn`` with role ``"assistant"``, the given ``agent_id`` and
        the user's (tag-wrapped, when applicable) input as content.

    Raises:
        KeyboardInterrupt: If the user enters ``/quit`` or ``/q``.
    """
    # Detect input type and formatting
    input_type, start_tag, end_tag = _detect_input_type(regex)

    # The pattern is loop-invariant, so compile it once up front. A compile
    # failure is remembered and reported after the user has typed something.
    pattern = None
    regex_error = None
    if regex:
        try:
            pattern = re.compile(regex)
        except re.error as exc:
            regex_error = exc

    while True:
        _clear_terminal()
        print(_render_chat(state))
        _print_input_instructions(regex, input_type)

        user_in = (await _async_input("> ")).rstrip("\n")

        # Commands
        command = user_in.strip().lower()
        if command in {"/help", "/h"}:
            _print_command_help()
            # Wait for the user: the next iteration clears the screen, which
            # would otherwise wipe the help text before it can be read.
            print(f"{_Style.FG_YELLOW}Press Enter to continue...{_Style.RESET}")
            await _async_input("")
            continue
        if command in {"/refresh", "/r"}:
            continue
        if command in {"/quit", "/q"}:
            raise KeyboardInterrupt("Human aborted run from human_policy")

        # Add formatting tags if needed
        if start_tag and end_tag:
            formatted_input = f"{start_tag}{user_in}{end_tag}"
        else:
            formatted_input = user_in

        if pattern is None:
            if regex_error is not None:
                # If regex is invalid, fall back to accepting any input
                print(
                    f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {regex_error}. Accepting input without validation."
                )
                await asyncio.sleep(0.5)
            return ChatTurn(
                role="assistant", agent_id=agent_id, content=formatted_input
            )

        # Validate against regex (fullmatch)
        if pattern.fullmatch(formatted_input):
            return ChatTurn(
                role="assistant", agent_id=agent_id, content=formatted_input
            )

        # Show validation error and re-prompt
        _print_validation_error(regex, input_type, formatted_input)
        print(_horizontal_rule("."))
        print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}")
        await _async_input("")


def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]:
    """
    Return the registry of human-driven policies, keyed by policy name.

    The single entry maps ``"human_policy"`` to the async ``human_policy``
    callable.
    """
    # The declared Callable[[List[Dict]], str] type does not describe an async
    # callable; the mismatch is intentional and silenced for the type checker.
    registry = dict(human_policy=human_policy)
    return registry  # type: ignore[return-value]