| import asyncio |
| import os |
| import re |
| import shutil |
| import sys |
| from typing import Callable, Dict, List, Optional |
|
|
| from mllm.markov_games.rollout_tree import ChatTurn |
|
|
| try: |
| import rstr |
| except Exception: |
| rstr = None |
|
|
|
|
def _clear_terminal() -> None:
    """
    Wipe the terminal screen when stdout is attached to a TTY.

    Runs ``cls`` when ``os.name == "nt"`` and ``clear`` otherwise. When
    stdout is not a TTY (e.g. piped or captured output) nothing is executed.
    """
    if not sys.stdout.isatty():
        return
    command = "cls" if os.name == "nt" else "clear"
    os.system(command)
|
|
|
|
def _terminal_width(default: int = 100) -> int:
    """
    Return the terminal width in columns, or ``default`` when it is unknown.

    ``shutil.get_terminal_size`` does not raise when the size cannot be
    determined; it silently returns its own fallback (80 columns). Passing
    ``default`` as the fallback makes the parameter actually take effect.

    Args:
        default: Column count to report when the terminal size cannot be
            queried (no ``COLUMNS`` variable and no attached terminal).

    Returns:
        The detected column count, or ``default``.
    """
    try:
        # 24 is only the line-count half of the fallback tuple; it is unused here.
        return shutil.get_terminal_size(fallback=(default, 24)).columns
    except Exception:
        # Deliberately best-effort: width detection must never break rendering.
        return default
|
|
|
|
def _horizontal_rule(char: str = "─") -> str:
    """
    Build a divider line by repeating ``char``.

    The repeat count is the terminal width minus two, but never fewer
    than 20.
    """
    usable_columns = _terminal_width() - 2
    repeat = usable_columns if usable_columns > 20 else 20
    return char * repeat
|
|
|
|
class _Style:
    """ANSI escape sequences used to style text printed to the terminal."""

    # Bright foreground colours (codes 91-96).
    FG_RED = "\033[91m"
    FG_GREEN = "\033[92m"
    FG_YELLOW = "\033[93m"
    FG_BLUE = "\033[94m"
    FG_MAGENTA = "\033[95m"
    FG_CYAN = "\033[96m"

    # Text attributes.
    BOLD = "\033[1m"
    DIM = "\033[2m"

    # Emitted after styled text to switch all attributes back off.
    RESET = "\033[0m"
|
|
|
|
def _render_chat(state) -> str:
    """
    Format the conversation history as a block of text for the terminal.

    Each item of ``state`` is read through its ``role`` and ``content``
    attributes. Turns with role ``"assistant"`` get the HUMAN header, turns
    with role ``"user"`` get the USER header, and any other role is shown
    upper-cased in brackets. Content lines are indented under the header.

    Returns:
        The rendered text, framed by horizontal rules, joined with newlines.
    """
    out: List[str] = [
        _horizontal_rule(),
        f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}",
        _horizontal_rule(),
    ]
    for turn in state:
        speaker = turn.role
        body = str(turn.content).strip()

        # NOTE(review): "assistant" turns are labelled HUMAN, presumably
        # because the human operator plays the assistant side; confirm.
        if speaker == "assistant":
            label = f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑💻{_Style.RESET}"
        elif speaker == "user":
            label = f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}"
        else:
            label = f"[{_Style.DIM}{speaker.upper()}{_Style.RESET}]"
        out.append(label)

        # Empty content still produces one (blank) indented line.
        out.extend(f" {row}" for row in (body.splitlines() or [""]))
        out.append("")
    out.append(_horizontal_rule())
    return "\n".join(out)
|
|
|
|
async def _async_input(prompt_text: str) -> str:
    """
    Read one line from the user without blocking the event loop.

    The blocking ``input(prompt_text)`` call is handed to
    ``asyncio.to_thread`` and its result is awaited.
    """
    typed = await asyncio.to_thread(input, prompt_text)
    return typed
|
|
|
|
def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]:
    """
    Generate a short sample string matching ``regex``, if possible.

    Up to 20 candidates are drawn from ``rstr.xeger``; the first one whose
    length is at most ``max_len`` is returned.

    Returns:
        A matching example, or ``None`` when ``rstr`` is unavailable, when
        generation raises, or when no candidate was short enough.
    """
    if rstr is None:
        return None

    attempts_left = 20
    try:
        while attempts_left > 0:
            sample = rstr.xeger(regex)
            if len(sample) <= max_len:
                return sample
            attempts_left -= 1
    except Exception:
        # Best-effort helper: any generation failure just means "no example".
        return None
    return None
|
|
|
|
| def _detect_input_type(regex: str | None) -> tuple[str, str, str]: |
| """ |
| Detect what type of input is expected based on the regex pattern. |
| Returns (input_type, start_tag, end_tag) |
| """ |
| if regex is None: |
| return "text", "", "" |
|
|
| if "message_start" in regex and "message_end" in regex: |
| return "message", "<<message_start>>", "<<message_end>>" |
| elif "proposal_start" in regex and "proposal_end" in regex: |
| return "proposal", "<<proposal_start>>", "<<proposal_end>>" |
| else: |
| return "text", "", "" |
|
|
|
|
async def human_policy(state, agent_id, regex: str | None = None) -> "ChatTurn":
    """
    Async human-in-the-loop policy.

    - Displays prior conversation context in the terminal.
    - Prompts the user for a response.
    - If a regex is provided, validates and re-prompts until it matches.
    - Automatically adds formatting tags based on expected input type.

    Args:
        state: Iterable of prior chat turns; it is rendered with
            ``_render_chat``, which reads each turn's ``role`` and ``content``.
        agent_id: Identifier stored on the returned ``ChatTurn``.
        regex: Optional fullmatch validation pattern. If it fails to compile,
            a warning is printed and the input is accepted unvalidated.

    Returns:
        A ``ChatTurn`` with role ``"assistant"`` whose content is the user's
        input wrapped in the detected start/end tags (if any).

    Raises:
        KeyboardInterrupt: If the user enters ``/quit`` or ``/q``.
    """
    input_type, start_tag, end_tag = _detect_input_type(regex)

    # The pattern never changes between attempts, so compile it once here
    # instead of on every loop iteration. A compile failure is remembered and
    # reported after the first real input, as before.
    pattern: Optional[re.Pattern] = None
    regex_error: Optional[re.error] = None
    if regex is not None:
        try:
            pattern = re.compile(regex)
        except re.error as exc:
            regex_error = exc

    while True:
        _clear_terminal()
        print(_render_chat(state))

        if regex:
            example = _short_regex_example(regex, max_len=30)
            print(
                f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}"
            )
            print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
            if example:
                print(
                    f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}"
                )
            print(_horizontal_rule("."))

            if input_type == "message":
                print(
                    f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}"
                )
            elif input_type == "proposal":
                print(
                    f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}"
                )
            else:
                print(
                    f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}"
                )

            print(
                f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}"
            )
        else:
            print(
                f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}"
            )

        user_in = (await _async_input("> ")).rstrip("\n")

        # Slash commands are matched case-insensitively, ignoring outer spaces.
        command = user_in.strip().lower()
        if command in {"/help", "/h"}:
            print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}")
            print(
                f" {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET} Show this help"
            )
            print(
                f" {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET} Re-render the conversation and prompt"
            )
            print(
                f" {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET} Abort the run (raises KeyboardInterrupt)"
            )
            # NOTE(review): the next iteration clears the screen, so the help
            # text is only visible for this one-second pause.
            await asyncio.sleep(1.0)
            continue
        if command in {"/refresh", "/r"}:
            continue
        if command in {"/quit", "/q"}:
            raise KeyboardInterrupt("Human aborted run from human_policy")

        # Tags are either both set or both empty, so plain text passes through
        # unchanged.
        if start_tag and end_tag:
            formatted_input = f"{start_tag}{user_in}{end_tag}"
        else:
            formatted_input = user_in

        if regex is None:
            accepted = True
        elif pattern is None:
            # The regex could not be compiled: warn and accept without validation.
            print(
                f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {regex_error}. Accepting input without validation."
            )
            await asyncio.sleep(0.5)
            accepted = True
        else:
            accepted = pattern.fullmatch(formatted_input) is not None

        if accepted:
            return ChatTurn(
                role="assistant", agent_id=agent_id, content=formatted_input
            )

        # Validation failed: explain what was submitted and wait for Enter
        # before re-rendering the prompt.
        print("")
        print(
            f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again."
        )

        if input_type == "message":
            print(
                f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
            )
            print("Just type the message content without tags.")
        elif input_type == "proposal":
            print(
                f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
            )
            print("Just type the number without tags.")
        else:
            print("Expected (regex):")
            print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")

        print(_horizontal_rule("."))
        print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}")
        await _async_input("")
|
|
|
|
def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]:
    """
    Return the registry that maps policy names to policy callables.

    It currently holds a single entry, ``"human_policy"``.
    """
    # NOTE(review): the annotation describes a callable taking a list of dicts
    # and returning str, whereas human_policy is a coroutine function taking
    # (state, agent_id, regex); confirm what callers expect.
    registry = {}
    registry["human_policy"] = human_policy
    return registry
|
|