Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import os.path as osp | |
| from typing import Union | |
| from colorama import Fore | |
| from colorama import Style as CRStyle | |
| from prompt_toolkit import prompt | |
| from prompt_toolkit.completion import WordCompleter | |
| from prompt_toolkit.styles import Style | |
| from rich.console import Console | |
| from agentreview.utility.utils import get_rebuttal_dir, load_llm_ac_decisions, \ | |
| save_llm_ac_decisions | |
| from ..arena import Arena, TooManyInvalidActions | |
| from ..backends.human import HumanBackendError | |
| from ..const import AGENTREVIEW_LOGO | |
| from ..environments import PaperReview, PaperDecision | |
# ASCII-art generator (Big font), link kept from the original author:
# https://patorjk.com/software/taag/#p=display&f=Big&t=Chat%20Arena

# Color name -> colorama foreground escape code used for terminal output.
# NOTE(review): the role hints below assume player_names is ordered
# Paper Extractor, AC, Author, R1, R2 — colors are handed out by player order.
color_dict = dict(
    red=Fore.RED,
    green=Fore.GREEN,
    blue=Fore.BLUE,  # Paper Extractor
    light_red=Fore.LIGHTRED_EX,  # AC
    light_green=Fore.LIGHTGREEN_EX,  # Author
    yellow=Fore.YELLOW,  # R1
    magenta=Fore.MAGENTA,  # R2
    cyan=Fore.CYAN,
    white=Fore.WHITE,
    black=Fore.BLACK,
    light_yellow=Fore.LIGHTYELLOW_EX,
    light_blue=Fore.LIGHTBLUE_EX,
    light_magenta=Fore.LIGHTMAGENTA_EX,
    light_cyan=Fore.LIGHTCYAN_EX,
    light_white=Fore.LIGHTWHITE_EX,
    light_black=Fore.LIGHTBLACK_EX,
)

# Colors eligible for players, in color_dict order: everything except plain
# black/white/red/green and any "grey" variant.
visible_colors = list(
    filter(
        lambda name: name not in ("black", "white", "red", "green") and "grey" not in name,
        color_dict,
    )
)
# colorama is required for the colored terminal output.
# NOTE(review): `from colorama import ...` near the top of this module already
# raises a plain ImportError when colorama is missing, so this friendlier
# message is never reached; consider moving the guard above those imports.
try:
    import colorama
except ImportError:
    _install_hint = "Please install colorama: `pip install colorama`"
    raise ImportError(_install_hint)

# Step cap that ArenaCLI.launch falls back to when running non-interactively
# without an explicit max_steps. We should not need this parameter for paper
# reviews anyway.
MAX_STEPS = 20

# Applied at import time: the root logger only emits ERROR and above.
logging.getLogger().setLevel(level=logging.ERROR)
class ArenaCLI:
    """The CLI user interface for ChatArena.

    Steps an ``Arena`` until it terminates, or until one of these ends the run:
    the environment's last phase, the step budget, or a user command. It prints
    newly produced messages to the terminal and finally persists the run's
    results.
    """

    def __init__(self, arena: Arena):
        """Wrap ``arena``; ``arena.args`` supplies output and experiment settings."""
        self.arena = arena
        self.args = arena.args

    def launch(self, max_steps: int = None, interactive: bool = True):
        """Run the CLI.

        Args:
            max_steps: Maximum number of arena steps to run. When ``None`` the
                run is unbounded in interactive mode and capped at ``MAX_STEPS``
                otherwise. A ``reset`` command restarts the count.
            interactive: Prompt for a command (n/r/q/s/h) before every step and
                ask the user to type input whenever a human-backed player must act.

        Raises:
            HumanBackendError: A human-backed player must act while
                ``interactive`` is False.
            NotImplementedError: The environment type is neither
                ``"paper_review"`` nor ``"paper_decision"``.
            Exception: For ``"paper_review"``, the history file for this paper
                already exists.
        """
        if not interactive and max_steps is None:
            max_steps = MAX_STEPS
        args = self.args
        console = Console()

        timestep = self.arena.reset()
        console.print("🎓AgentReview Initialized!", style="bold green")

        env: Union[PaperReview, PaperDecision] = self.arena.environment
        players = self.arena.players
        env_desc = self.arena.global_prompt

        name_to_color = self._assign_player_colors(env.player_names)
        print("name_to_color: ", name_to_color)
        # System and Moderator messages are printed in red
        name_to_color["System"] = "red"
        name_to_color["Moderator"] = "red"

        def color_of(name: str) -> str:
            # Speakers without an assigned color fall back to white instead of
            # raising KeyError.
            return color_dict[name_to_color.get(name, "white")]

        console.print(f"[bold green underline]Environment ({env.type_name}) description:[/]")
        # The description is free text (a prompt); print it verbatim so square
        # brackets inside it are not parsed as rich markup.
        console.print(env_desc, markup=False)

        # Player name, backend type and role description. These are logged at
        # INFO, below the ERROR level this module sets on the root logger, so
        # they stay hidden unless the level is lowered.
        for player in players:
            player_color = color_of(player.name)
            player_name_str = f"[{player.name} ({player.backend.type_name})] Role Description:"
            logging.info(player_color + player_name_str + CRStyle.RESET_ALL)
            logging.info(player_color + player.role_desc + CRStyle.RESET_ALL)

        console.print(Fore.GREEN + "\n========= Arena Start! ==========\n" + CRStyle.RESET_ALL)

        step = 0
        while not timestep.terminal:
            if self._past_final_phase(env):
                break

            if interactive:
                command = self._prompt_command()
                if command in ("help", "h"):
                    self._print_help(console)
                    continue
                if command in ("exit", "quit", "q"):
                    break
                if command in ("reset", "r"):
                    timestep = self.arena.reset()
                    # A reset starts a fresh run, so the step budget starts over.
                    step = 0
                    console.print("\n========= Arena Reset! ==========\n", style="bold green")
                    continue
                if command in ("save", "s"):
                    # After saving, the arena still advances one step, as with "next".
                    self._save_history_interactively(console)
                elif command not in ("next", "n", ""):
                    console.print(f"Invalid command: {command}", style="bold red", markup=False)
                    continue

            try:
                timestep = self.arena.step()
            except HumanBackendError:
                # A human-backed player has to act: read their input and feed it
                # straight to the environment so the conversation continues.
                human_player_name = env.get_next_player()
                if not interactive:
                    raise  # cannot recover from this error in non-interactive mode
                human_input = prompt(
                    [("class:user_prompt", f"Type your input for {human_player_name}: ")],
                    style=Style.from_dict({"user_prompt": "ansicyan underline"}),
                )
                timestep = env.step(human_player_name, human_input)
            except TooManyInvalidActions as e:
                # Report the actual failure (previously a placeholder string was printed).
                print(Fore.RED + f"Too many invalid actions: {e}" + CRStyle.RESET_ALL)
                break

            # Show the messages that have not been shown yet, then mark them logged.
            new_messages = [msg for msg in env.get_observation() if not msg.logged]
            for msg in new_messages:
                # NOTE(review): messages are printed only when ``skip_logging`` is
                # truthy, which reads inverted given the flag's name — confirm intent.
                if args.skip_logging:
                    message_str = f"[{msg.agent_name}->{msg.visible_to}]: {msg.content}"
                    # markup=False: model output may contain "[...]" sequences
                    # that rich would otherwise parse as (possibly invalid) tags.
                    console.print(color_of(msg.agent_name) + message_str + CRStyle.RESET_ALL, markup=False)
                msg.logged = True

            step += 1
            if max_steps is not None and step >= max_steps:
                break

        console.print("\n========= Arena Ended! ==========\n", style="bold red")
        self._save_results(env)

    @staticmethod
    def _assign_player_colors(player_names) -> dict:
        """Map each player name to a color name taken from ``visible_colors``.

        Colors are handed out in order and reused cyclically, so every player
        gets a color even when there are more players than visible colors.
        """
        return {
            name: visible_colors[index % len(visible_colors)]
            for index, name in enumerate(player_names)
        }

    @staticmethod
    def _past_final_phase(env) -> bool:
        """Return True once ``env`` has moved beyond its last phase handled by the CLI.

        Raises:
            NotImplementedError: ``env.type_name`` is not a supported type.
        """
        if env.type_name == "paper_review":
            return env.phase_index > 4
        if env.type_name == "paper_decision":
            # Phase 5: AC makes decisions
            return env.phase_index > 5
        raise NotImplementedError(f"Unknown environment type: {env.type_name}")

    @staticmethod
    def _prompt_command() -> str:
        """Ask the user for the next CLI command; return it stripped of whitespace."""
        command = prompt(
            [("class:command", "command (n/r/q/s/h) > ")],
            style=Style.from_dict({"command": "blue"}),
            completer=WordCompleter(
                ["next", "n", "reset", "r", "exit", "quit", "q", "help", "h", "save", "s"]
            ),
        )
        return command.strip()

    @staticmethod
    def _print_help(console):
        """Print the list of available interactive commands."""
        console.print("Available commands:")
        console.print("  [bold]next or n or <Enter>[/]: next step")
        console.print("  [bold]exit or quit or q[/]: exit the game")
        console.print("  [bold]help or h[/]: print this message")
        console.print("  [bold]reset or r[/]: reset the game")
        console.print("  [bold]save or s[/]: save the history to file")

    def _save_history_interactively(self, console):
        """Prompt for a file path and save the arena history there.

        An empty path is rejected with a message instead of being passed to
        ``save_history``.
        """
        file_path = prompt(
            [("class:command", "save file path > ")],
            style=Style.from_dict({"command": "blue"}),
        ).strip()
        if not file_path:
            console.print("No file path given; history not saved.", style="bold red")
            return
        self.arena.save_history(file_path)
        console.print(f"History saved to {file_path}", style="bold green", markup=False)

    def _save_results(self, env):
        """Persist the outcome of a finished run.

        - ``paper_review``: write the arena history to
          ``<rebuttal_dir>/<paper_id>.json``. It refuses to overwrite an
          existing file.
        - ``paper_decision``: append this environment's ``ac_decisions`` to the
          previously saved AC decisions, then save the combined list.
        - Any other environment type: nothing is saved.

        Raises:
            Exception: The review-history file already exists.
        """
        args = self.args
        if env.type_name == "paper_review":
            paper_id = self.arena.environment.paper_id
            rebuttal_dir = get_rebuttal_dir(
                output_dir=args.output_dir,
                paper_id=paper_id,
                experiment_name=args.experiment_name,
                model_name=args.model_name,
                conference=args.conference,
            )
            os.makedirs(rebuttal_dir, exist_ok=True)
            path_review_history = f"{rebuttal_dir}/{paper_id}.json"
            if osp.exists(path_review_history):
                raise Exception(f"History already exists!! ({path_review_history}). There must be something wrong with "
                                f"the path to save the history ")
            self.arena.save_history(path_review_history)
        elif env.type_name == "paper_decision":
            ac_decisions = load_llm_ac_decisions(
                output_dir=args.output_dir,
                conference=args.conference,
                model_name=args.model_name,
                ac_scoring_method=args.ac_scoring_method,
                experiment_name=args.experiment_name,
                num_papers_per_area_chair=args.num_papers_per_area_chair,
            )
            ac_decisions += [env.ac_decisions]
            save_llm_ac_decisions(
                ac_decisions,
                output_dir=args.output_dir,
                conference=args.conference,
                model_name=args.model_name,
                ac_scoring_method=args.ac_scoring_method,
                experiment_name=args.experiment_name,
            )