Spaces:
Sleeping
Sleeping
| # agent.py | |
| import os | |
| import random | |
| from typing import TypedDict, Literal, List, Optional | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage | |
| from langchain_core.tools import BaseTool | |
| # Import necessary poke-env components for type hinting and functionality | |
| from poke_env.player import Player | |
| from poke_env.environment.battle import Battle | |
| from poke_env.environment.move import Move | |
| from poke_env.environment.pokemon import Pokemon | |
| from tools import tools | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file into os.environ so that
# OPENAI_API_KEY (checked in OpenAIAgent.__init__) can be supplied there.
load_dotenv()
# State object threaded through the LangGraph decision nodes for one turn.
class GraphState(TypedDict):
    """Keys read and written by the decision-graph nodes during a single turn."""

    # Text rendering of the battle that is shown to the LLM.
    battle_state_str: str
    # Conversation history exchanged with the LLM.
    messages: List[BaseMessage]
    # The LLM's raw reply, or None if the call has not succeeded.
    generation: AIMessage | None
    # Parsed tool call from the LLM: {"name": ..., "arguments": ...}.
    action_to_take: dict | None
    # Chosen action as a string: a move ID or a switch-target species.
    chosen_action_id: str | None
    # Description of any error hit while deciding.
    error_message: str | None
    # Set once the random-fallback path has been requested.
    fallback_triggered: bool
    # Agent's name, carried along for logging.
    agent_name: str | None
class OpenAIAgent(Player):
    """
    An AI agent for Pokemon Showdown that uses LangGraph
    with function calling to decide its moves.

    Each turn, ``choose_move`` renders the battle as text and asks an OpenAI
    chat model (bound to the project's ``tools``) to call ``choose_move`` or
    ``choose_switch``. It then maps that tool call back onto one of
    ``battle.available_moves`` / ``battle.available_switches``. If the LLM call
    fails or its answer cannot be mapped, a random available action is used.

    Requires OPENAI_API_KEY environment variable to be set.
    """

    def __init__(self, *args, **kwargs):
        """
        Create the agent and build its decision graph.

        All positional and keyword arguments are forwarded to ``Player``.

        Raises:
            ValueError: if OPENAI_API_KEY is not set in the environment.
        """
        # Check configuration before Player.__init__ runs, so a missing key
        # fails fast instead of after the Player has already been set up.
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set or loaded.")
        super().__init__(*args, **kwargs)

        self.model_name = "gpt-4o"
        self.llm = ChatOpenAI(api_key=api_key, model=self.model_name, temperature=0.5)
        # The LLM is expected to answer by calling one of these tools.
        self.llm_with_tools = self.llm.bind_tools(tools)
        self.battle_history = []  # Optional: to potentially add context later
        # Battle currently being decided. Graph nodes read it from here because
        # the Battle object is kept out of the graph state and config.
        self.current_battle: Optional[Battle] = None

        self.graph = self._build_graph()
        # NOTE(review): choose_move uses a new thread_id every turn, so saved
        # checkpoints are never read back, and MemorySaver presumably retains
        # them for the agent's lifetime. Consider compiling without it.
        self.checkpointer = MemorySaver()
        self.compiled_graph = self.graph.compile(checkpointer=self.checkpointer)

    # --- Name / enum helpers ---

    @staticmethod
    def _normalize_id(name) -> str:
        """
        Reduce a name to ID form: alphanumeric characters only, lowercased.

        e.g. "King's Shield" -> "kingsshield", "U-turn" -> "uturn",
        "Mr. Mime" -> "mrmime".
        """
        return "".join(ch for ch in str(name) if ch.isalnum()).lower()

    @staticmethod
    def _enum_name(value) -> str:
        """Return ``value.name`` for enum members, otherwise ``str(value)``."""
        name = getattr(value, "name", None)
        return name if isinstance(name, str) else str(value)

    @classmethod
    def _describe_effects(cls, effects) -> str:
        """
        Render a weather / field / side-condition value for the prompt.

        Accepts a mapping (keys rendered by enum name, followed by their
        value), a single enum-like value, or an empty/None value ("None").
        """
        if not effects:
            return "None"
        if isinstance(effects, dict):
            return ", ".join(f"{cls._enum_name(key)}: {val}" for key, val in effects.items())
        return cls._enum_name(effects)

    @classmethod
    def _describe_pokemon(cls, label: str, pkmn: Optional[Pokemon]) -> str:
        """One-line summary of an active Pokemon, or a placeholder if it is None."""
        if pkmn is None:
            return f"{label}: None"
        # A single-typed Pokemon may report None as its second type; skip it.
        types = "/".join(cls._enum_name(t) for t in pkmn.types if t is not None)
        status = pkmn.status.name if pkmn.status else "None"
        return (
            f"{label}: {pkmn.species} "
            f"(Type: {types}) "
            f"HP: {pkmn.current_hp_fraction * 100:.1f}% "
            f"Status: {status} "
            f"Boosts: {pkmn.boosts}"
        )

    def _format_battle_state(self, battle: Battle) -> str:
        """
        Format the current battle state into a plain-text prompt for the LLM.

        Covers both active Pokemon, the available moves and switches, weather,
        fields and both sides' conditions. A missing active Pokemon is shown
        as "None" instead of raising.
        """
        active_info = self._describe_pokemon("Your active Pokemon", battle.active_pokemon)
        opponent_info = self._describe_pokemon(
            "Opponent's active Pokemon", battle.opponent_active_pokemon
        )

        if battle.available_moves:
            move_lines = [
                f"- {move.id} (Type: {self._enum_name(move.type)}, BP: {move.base_power}, "
                f"Acc: {move.accuracy}, PP: {move.current_pp}/{move.max_pp}, "
                f"Cat: {self._enum_name(move.category)})"
                for move in battle.available_moves
            ]
        else:
            move_lines = ["- None (Must switch or Struggle)"]

        if battle.available_switches:
            switch_lines = [
                f"- {pkmn.species} (HP: {pkmn.current_hp_fraction * 100:.1f}%, "
                f"Status: {pkmn.status.name if pkmn.status else 'None'})"
                for pkmn in battle.available_switches
            ]
        else:
            switch_lines = ["- None"]

        field_lines = [
            f"Weather: {self._describe_effects(battle.weather)}",
            f"Terrains: {self._describe_effects(battle.fields)}",
            f"Your Side Conditions: {self._describe_effects(battle.side_conditions)}",
            f"Opponent Side Conditions: {self._describe_effects(battle.opponent_side_conditions)}",
        ]

        sections = [
            f"{active_info}\n{opponent_info}",
            "Available moves:\n" + "\n".join(move_lines),
            "Available switches:\n" + "\n".join(switch_lines),
            "\n".join(field_lines),
        ]
        return "\n\n".join(sections).strip()

    # --- LangGraph Node Definitions ---

    async def _get_llm_decision_node(self, state: GraphState) -> dict:
        """
        Ask the tool-bound LLM for an action, given ``state["battle_state_str"]``.

        Returns a partial state update. On success it holds the prompt plus the
        reply in ``messages`` and the reply in ``generation``. On failure it
        holds the error text and sets ``fallback_triggered``.
        """
        system_prompt = (
            "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
            "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
            "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
            "Only choose actions listed as available by calling the appropriate tool ('choose_move' or 'choose_switch')."
        )
        # Real newlines (not an escaped "\\n"), so the model sees separate lines.
        user_prompt = (
            f"Current Battle State:\n{state['battle_state_str']}\n\nChoose the best action."
        )
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
        try:
            generation = await self.llm_with_tools.ainvoke(messages)
        except Exception as e:  # any API/network failure routes to the fallback node
            print(f"Error during LLM call: {e}")
            return {
                "messages": messages,
                "generation": None,
                "action_to_take": None,
                "error_message": str(e),
                "fallback_triggered": True,
            }
        return {
            "messages": messages + [generation],
            "generation": generation,
            "action_to_take": None,
            "error_message": None,
            "fallback_triggered": False,
        }

    def _resolve_tool_call(
        self, battle: Battle, function_name: str, args: dict
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Map a single tool call onto an available action.

        Returns ``(chosen_action_id, None)`` on success, where the ID is a move
        ID or a switch species. Returns ``(None, error_message)`` otherwise.
        """
        if function_name == "choose_move":
            move_name = args.get("move_name")
            if not move_name:
                return None, "LLM 'choose_move' called without 'move_name'."
            move = self._find_move_by_name(battle, move_name)
            if move is None:
                return None, f"LLM chose unavailable/invalid move '{move_name}'."
            return move.id, None

        if function_name == "choose_switch":
            pokemon_name = args.get("pokemon_name")
            if not pokemon_name:
                return None, "LLM 'choose_switch' called without 'pokemon_name'."
            pokemon = self._find_pokemon_by_name(battle, pokemon_name)
            if pokemon is None:
                return None, f"LLM chose unavailable/invalid switch '{pokemon_name}'."
            return pokemon.species, None

        return None, f"LLM called unknown tool: {function_name}"

    def _process_llm_tool_call_node(self, state: GraphState, config: dict) -> dict:
        """
        Validate the LLM's tool call and resolve it to an action ID.

        Only the first tool call is considered. Any problem is reported in
        ``error_message`` with ``fallback_triggered`` set, so the graph routes
        to the fallback node.
        """
        generation = state.get("generation")
        error_message = state.get("error_message")
        action_to_take = None
        chosen_action_id = None

        if generation and generation.tool_calls:
            tool_call = generation.tool_calls[0]  # assume a single tool call
            function_name = tool_call["name"]
            args = tool_call.get("args") or {}
            action_to_take = {"name": function_name, "arguments": args}
            chosen_action_id, resolve_error = self._resolve_tool_call(
                self.current_battle, function_name, args
            )
            error_message = resolve_error or error_message
        elif not error_message:
            error_message = "LLM did not call a tool."

        # Any error (new or carried over) forces the fallback path.
        fallback_triggered = bool(state.get("fallback_triggered")) or bool(error_message)
        return {
            "action_to_take": action_to_take,
            "chosen_action_id": chosen_action_id,
            "error_message": error_message,
            "fallback_triggered": fallback_triggered,
        }

    def _handle_fallback_node(self, state: GraphState, config: dict) -> dict:
        """
        Pick a random available move or switch after the LLM path failed.

        ``chosen_action_id`` stays None when nothing is available. In that case
        ``choose_move`` falls back to ``choose_random_move``.
        """
        battle = self.current_battle
        print(f"Fallback triggered. Error: {state.get('error_message') or 'Unknown'}. Choosing random move/switch.")
        options = list(battle.available_moves) + list(battle.available_switches)
        chosen_action_id = None
        if options:
            choice = random.choice(options)
            if isinstance(choice, Move):
                chosen_action_id = choice.id
            elif isinstance(choice, Pokemon):
                chosen_action_id = choice.species
        return {"chosen_action_id": chosen_action_id, "fallback_triggered": True}

    # --- LangGraph Conditional Edges ---

    def _should_fallback(self, state: GraphState) -> Literal["fallback", "proceed"]:
        """Route to the fallback node on any error, or when the LLM made no tool call."""
        if state.get("fallback_triggered") or state.get("error_message"):
            return "fallback"
        generation = state.get("generation")
        if not generation or not generation.tool_calls:
            return "fallback"
        return "proceed"

    def _should_fallback_after_processing(self, state: GraphState) -> str:
        """
        Return "fallback" if processing failed to produce an action ID, else END.
        """
        if state.get("fallback_triggered") and not state.get("chosen_action_id"):
            return "fallback"
        return END

    # --- LangGraph Definition ---

    def _build_graph(self) -> StateGraph:
        """
        Build the decision graph (uncompiled).

        get_llm_decision -> (process_llm_tool_call | handle_fallback);
        process_llm_tool_call -> (handle_fallback | END); handle_fallback -> END.
        """
        graph = StateGraph(GraphState)
        graph.add_node("get_llm_decision", self._get_llm_decision_node)
        graph.add_node("process_llm_tool_call", self._process_llm_tool_call_node)
        graph.add_node("handle_fallback", self._handle_fallback_node)
        graph.set_entry_point("get_llm_decision")
        graph.add_conditional_edges(
            "get_llm_decision",
            self._should_fallback,
            {
                "fallback": "handle_fallback",
                "proceed": "process_llm_tool_call",
            },
        )
        graph.add_conditional_edges(
            "process_llm_tool_call",
            self._should_fallback_after_processing,
            {
                "fallback": "handle_fallback",
                END: END,
            },
        )
        graph.add_edge("handle_fallback", END)
        return graph

    # --- Lookup helpers ---

    def _find_move_by_name(self, battle: Battle, move_name: str) -> Move | None:
        """
        Return the available Move whose ID matches ``move_name``, else None.

        Both sides are normalized to ID form, so "Close Combat",
        "close-combat" and "closecombat" all match.
        """
        target = self._normalize_id(move_name)
        return next(
            (move for move in battle.available_moves if self._normalize_id(move.id) == target),
            None,
        )

    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Pokemon | None:
        """
        Return the available switch whose species matches ``pokemon_name``, else None.

        Both sides are normalized to ID form, so "Mr. Mime" matches "mrmime".
        """
        target = self._normalize_id(pokemon_name)
        return next(
            (
                pkmn
                for pkmn in battle.available_switches
                if self._normalize_id(pkmn.species) == target
            ),
            None,
        )

    async def choose_move(self, battle: Battle):
        """
        Main decision-making function called by poke-env each turn.

        Runs the LangGraph to obtain a move ID or switch species, then builds
        the order with ``create_order``. Falls back to ``choose_random_move``
        if formatting fails, the graph fails, or the returned ID matches
        nothing available.
        """
        # Graph nodes read the battle from this attribute.
        # NOTE(review): it is shared per instance, so concurrent battles on one
        # agent could overwrite each other while a decision is in flight.
        self.current_battle = battle

        try:
            battle_state_str = self._format_battle_state(battle)
        except Exception as e:
            print(f"Error formatting battle state: {e}")
            return self.choose_random_move(battle)

        initial_graph_state: GraphState = {
            "battle_state_str": battle_state_str,
            "messages": [],
            "generation": None,
            "action_to_take": None,
            "chosen_action_id": None,
            "error_message": None,
            "fallback_triggered": False,
            "agent_name": self.username,
        }
        # One checkpointer thread per battle turn. Use a per-battle ID instead
        # if memory across turns of the same battle is ever wanted.
        config = {
            "configurable": {
                "thread_id": f"battle_{battle.battle_tag}_turn_{battle.turn}"
            }
        }

        chosen_action_id = None
        try:
            final_state = await self.compiled_graph.ainvoke(initial_graph_state, config=config)
            chosen_action_id = (final_state or {}).get("chosen_action_id")
        except Exception as e:
            print(f"Error invoking LangGraph: {e}")

        if chosen_action_id:
            chosen_move = self._find_move_by_name(battle, chosen_action_id)
            if chosen_move:
                return self.create_order(chosen_move)
            chosen_switch = self._find_pokemon_by_name(battle, chosen_action_id)
            if chosen_switch:
                return self.create_order(chosen_switch)
            # Should not happen if the graph nodes validated the choice.
            print(f"Warning: Graph returned invalid action ID '{chosen_action_id}'. Falling back.")

        return self.choose_random_move(battle)