# pokemon_agents / agents.py
# Last commit by cgoncalves (da87fba):
#   Add .gitignore, implement LangGraph in OpenAIAgent, and create local battle script
# agent.py
import os
import random
from typing import TypedDict, Literal, List, Optional
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.tools import BaseTool
# Import necessary poke-env components for type hinting and functionality
from poke_env.player import Player
from poke_env.environment.battle import Battle
from poke_env.environment.move import Move
from poke_env.environment.pokemon import Pokemon
from tools import tools
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from dotenv import load_dotenv
# Load variables from a local .env file (python-dotenv) into the process
# environment, so OpenAIAgent.__init__ can read OPENAI_API_KEY via os.environ.
load_dotenv()
# State schema for the decision LangGraph (passed to StateGraph in
# OpenAIAgent._build_graph). Declared with the functional TypedDict form so
# each key can carry its own explanatory comment.
GraphState = TypedDict(
    "GraphState",
    {
        # Text rendering of the battle, built by OpenAIAgent._format_battle_state.
        "battle_state_str": str,
        # Conversation sent to / received from the LLM.
        "messages": List[BaseMessage],
        # Raw LLM reply (may carry tool calls); None if the call failed.
        "generation": Optional[AIMessage],
        # First tool call as {"name": ..., "arguments": ...}, if any.
        "action_to_take": Optional[dict],
        # Resolved action: a move id or a switch species, as a string.
        "chosen_action_id": Optional[str],
        # Description of any failure along the way.
        "error_message": Optional[str],
        # True once the random fallback path is required/used.
        "fallback_triggered": bool,
        # Player username, carried along for logging.
        "agent_name": Optional[str],
    },
)
class OpenAIAgent(Player):
    """
    An AI agent for Pokemon Showdown that uses LangGraph
    with function calling to decide its moves.

    Each turn, ``choose_move`` renders the battle as text, runs the compiled
    graph (LLM call -> tool-call validation -> random fallback when needed)
    and turns the resulting action id (a move id or a switch species) into a
    poke-env order.

    Requires OPENAI_API_KEY environment variable to be set.
    """

    def __init__(self, *args, **kwargs):
        """Set up the poke-env ``Player``, the tool-bound LLM and the graph.

        All arguments are forwarded unchanged to ``Player.__init__``.

        Raises:
            ValueError: if ``OPENAI_API_KEY`` is missing or empty.
        """
        super().__init__(*args, **kwargs)
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set or loaded.")
        self.model_name = "gpt-4o"
        # Initialize ChatOpenAI model
        self.llm = ChatOpenAI(api_key=api_key, model=self.model_name, temperature=0.5)
        # Bind the tools so the model can answer with structured tool calls
        # (expected: 'choose_move' / 'choose_switch', see _process_llm_tool_call_node).
        self.llm_with_tools = self.llm.bind_tools(tools)
        self.battle_history = []  # Optional: To potentially add context later
        # Build the LangGraph and compile it with an in-memory checkpointer.
        # NOTE(review): choose_move uses a new thread_id per battle turn, so the
        # saved checkpoints are never resumed and presumably just accumulate
        # for the lifetime of the agent.
        self.graph = self._build_graph()
        self.checkpointer = MemorySaver()
        self.compiled_graph = self.graph.compile(checkpointer=self.checkpointer)

    # --- Formatting / lookup helpers ---

    @staticmethod
    def _to_id(name) -> str:
        """Normalise a display name to id form: alphanumerics only, lowercased.

        e.g. "King's Shield" -> "kingsshield", "U-turn" -> "uturn",
        "Tapu Koko" -> "tapukoko". Assumes poke-env move ids / species are
        already in this lowercase-alphanumeric form.
        """
        return "".join(ch for ch in str(name) if ch.isalnum()).lower()

    @staticmethod
    def _enum_name(value) -> str:
        """Return ``value.name`` for enum-like values, otherwise ``str(value)``.

        NOTE(review): str() on poke-env enums may include extra text (e.g.
        "FIRE (pokemon type) object"); the bare name keeps the prompt compact.
        """
        return getattr(value, "name", None) or str(value)

    def _describe_active(self, label: str, pkmn: Pokemon | None) -> str:
        """Render one active Pokemon line for the prompt.

        ``pkmn`` may be None (poke-env can report no active Pokemon); in that
        case a placeholder line is returned instead of raising AttributeError.
        """
        if pkmn is None:
            return f"{label}: None"
        types = "/".join(self._enum_name(t) for t in pkmn.types if t is not None)
        status = pkmn.status.name if pkmn.status else "None"
        return (
            f"{label}: {pkmn.species} "
            f"(Type: {types}) "
            f"HP: {pkmn.current_hp_fraction * 100:.1f}% "
            f"Status: {status} "
            f"Boosts: {pkmn.boosts}"
        )

    def _format_battle_state(self, battle: Battle) -> str:
        """Formats the current battle state into a string for the LLM.

        Sections (separated by blank lines): both active Pokemon, available
        moves, available switches, then weather / fields / side conditions as
        poke-env exposes them.
        """
        own_info = self._describe_active("Your active Pokemon", battle.active_pokemon)
        opponent_info = self._describe_active(
            "Opponent's active Pokemon", battle.opponent_active_pokemon
        )

        move_lines = ["Available moves:"]
        if battle.available_moves:
            move_lines += [
                f"- {move.id} (Type: {self._enum_name(move.type)}, BP: {move.base_power}, "
                f"Acc: {move.accuracy}, PP: {move.current_pp}/{move.max_pp}, "
                f"Cat: {move.category.name})"
                for move in battle.available_moves
            ]
        else:
            move_lines.append("- None (Must switch or Struggle)")

        switch_lines = ["Available switches:"]
        if battle.available_switches:
            switch_lines += [
                f"- {pkmn.species} (HP: {pkmn.current_hp_fraction * 100:.1f}%, "
                f"Status: {pkmn.status.name if pkmn.status else 'None'})"
                for pkmn in battle.available_switches
            ]
        else:
            switch_lines.append("- None")

        field_lines = [
            f"Weather: {battle.weather}",
            f"Terrains: {battle.fields}",
            f"Your Side Conditions: {battle.side_conditions}",
            f"Opponent Side Conditions: {battle.opponent_side_conditions}",
        ]

        sections = [
            f"{own_info}\n{opponent_info}",
            "\n".join(move_lines),
            "\n".join(switch_lines),
            "\n".join(field_lines),
        ]
        return "\n\n".join(sections).strip()

    # --- LangGraph Node Definitions ---

    async def _get_llm_decision_node(self, state: GraphState) -> GraphState:
        """Invoke the tool-bound LLM on the current battle state.

        Returns a partial state update: on success ``generation`` holds the
        LLM reply; on any exception ``error_message`` is set and
        ``fallback_triggered`` is True.
        """
        battle_state_str = state["battle_state_str"]
        system_prompt = (
            "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
            "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
            "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
            "Only choose actions listed as available by calling the appropriate tool ('choose_move' or 'choose_switch')."
        )
        # Real newlines (not escaped "\\n") so the model sees a formatted state.
        user_prompt = f"Current Battle State:\n{battle_state_str}\n\nChoose the best action."
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
        try:
            generation = await self.llm_with_tools.ainvoke(messages)
            return {"messages": messages + [generation], "generation": generation, "action_to_take": None, "error_message": None, "fallback_triggered": False}
        except Exception as e:
            # Any LLM/API failure routes the graph to the random fallback.
            print(f"Error during LLM call: {e}")
            return {"messages": messages, "generation": None, "action_to_take": None, "error_message": str(e), "fallback_triggered": True}

    def _process_llm_tool_call_node(self, state: GraphState, config: dict) -> GraphState:
        """Validate the LLM's tool call and resolve it to an action id.

        Only the first tool call is considered. On success ``chosen_action_id``
        is a move id (``choose_move``) or a switch species (``choose_switch``).
        A missing argument, an unavailable move/switch, an unknown tool, or no
        tool call at all sets ``error_message`` and ``fallback_triggered``.
        """
        battle = self.current_battle  # set by choose_move before the graph runs
        generation = state.get("generation")
        action_to_take = None
        chosen_action_id = None
        error_message = state.get("error_message")
        fallback_triggered = state.get("fallback_triggered", False)

        if generation and generation.tool_calls:
            tool_call = generation.tool_calls[0]  # Assuming one tool call for now
            function_name = tool_call["name"]
            args = tool_call["args"]
            action_to_take = {"name": function_name, "arguments": args}

            if function_name == "choose_move":
                move_name = args.get("move_name")
                if not move_name:
                    error_message = "LLM 'choose_move' called without 'move_name'."
                else:
                    # _find_move_by_name only searches battle.available_moves.
                    chosen_move = self._find_move_by_name(battle, move_name)
                    if chosen_move is not None:
                        chosen_action_id = chosen_move.id
                    else:
                        error_message = f"LLM chose unavailable/invalid move '{move_name}'."
            elif function_name == "choose_switch":
                pokemon_name = args.get("pokemon_name")
                if not pokemon_name:
                    error_message = "LLM 'choose_switch' called without 'pokemon_name'."
                else:
                    # _find_pokemon_by_name only searches battle.available_switches.
                    chosen_switch = self._find_pokemon_by_name(battle, pokemon_name)
                    if chosen_switch is not None:
                        chosen_action_id = chosen_switch.species
                    else:
                        error_message = f"LLM chose unavailable/invalid switch '{pokemon_name}'."
            else:
                error_message = f"LLM called unknown tool: {function_name}"
        elif not error_message:  # No tool call and no previous error
            error_message = "LLM did not call a tool."

        if error_message:
            fallback_triggered = True

        return {
            "action_to_take": action_to_take,
            "chosen_action_id": chosen_action_id,
            "error_message": error_message,
            "fallback_triggered": fallback_triggered,
        }

    def _handle_fallback_node(self, state: GraphState, config: dict) -> GraphState:
        """Pick a uniformly random available move or switch as the action id.

        Leaves ``chosen_action_id`` as None when nothing is available; the
        caller (choose_move) then falls back to ``choose_random_move``.
        """
        battle = self.current_battle  # set by choose_move before the graph runs
        print(f"Fallback triggered. Error: {state.get('error_message', 'Unknown')}. Choosing random move/switch.")
        available_options = battle.available_moves + battle.available_switches
        chosen_action_id = None
        if available_options:
            random_choice = random.choice(available_options)
            if isinstance(random_choice, Move):
                chosen_action_id = random_choice.id
            elif isinstance(random_choice, Pokemon):
                chosen_action_id = random_choice.species
        return {"chosen_action_id": chosen_action_id, "fallback_triggered": True}

    # --- LangGraph Conditional Edges ---

    def _should_fallback(self, state: GraphState) -> Literal["fallback", "proceed"]:
        """Route after the LLM call: fallback on error or missing tool call."""
        if state.get("fallback_triggered"):
            return "fallback"
        if state.get("error_message"):  # If an error occurred in LLM or processing
            return "fallback"
        generation = state.get("generation")
        if not generation or not generation.tool_calls:
            return "fallback"  # No tool call from LLM
        return "proceed"

    def _should_fallback_after_processing(self, state: GraphState) -> str:
        """Route after tool-call processing.

        Returns "fallback" when processing flagged a failure without resolving
        an action id; otherwise returns LangGraph's ``END`` constant (which is
        also the key used in the path map in ``_build_graph``).
        """
        if state.get("fallback_triggered") and not state.get("chosen_action_id"):
            return "fallback"
        return END

    # --- LangGraph Definition ---

    def _build_graph(self) -> StateGraph:
        """Build (but do not compile) the decision graph.

        get_llm_decision -> (fallback | process_llm_tool_call)
        process_llm_tool_call -> (fallback | END); handle_fallback -> END.
        """
        graph = StateGraph(GraphState)
        graph.add_node("get_llm_decision", self._get_llm_decision_node)
        graph.add_node("process_llm_tool_call", self._process_llm_tool_call_node)
        graph.add_node("handle_fallback", self._handle_fallback_node)
        graph.set_entry_point("get_llm_decision")
        graph.add_conditional_edges(
            "get_llm_decision",
            self._should_fallback,
            {
                "fallback": "handle_fallback",
                "proceed": "process_llm_tool_call",
            },
        )
        graph.add_conditional_edges(
            "process_llm_tool_call",
            self._should_fallback_after_processing,
            {
                "fallback": "handle_fallback",
                END: END,  # processing succeeded (chosen_action_id set)
            },
        )
        graph.add_edge("handle_fallback", END)
        return graph

    def _find_move_by_name(self, battle: Battle, move_name: str) -> Move | None:
        """Return the available Move matching ``move_name``, or None.

        Both sides are normalised with ``_to_id`` so display names such as
        "King's Shield" or "U-turn" match their ids as well as raw ids do.
        """
        target = self._to_id(move_name)
        return next(
            (move for move in battle.available_moves if self._to_id(move.id) == target),
            None,
        )

    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Pokemon | None:
        """Return the available switch whose species matches ``pokemon_name``, or None.

        Both sides are normalised with ``_to_id`` so "Tapu Koko" and
        "tapukoko" compare equal.
        """
        target = self._to_id(pokemon_name)
        return next(
            (pkmn for pkmn in battle.available_switches if self._to_id(pkmn.species) == target),
            None,
        )

    async def choose_move(self, battle: Battle) -> str:
        """
        Main decision-making function called by poke-env each turn.
        Uses LangGraph to decide the action ID, then creates the order.

        Args:
            battle: the battle to act in. It is stored on ``self.current_battle``
                so the graph nodes can read it without putting it into the
                graph config/state.

        Returns:
            The order from ``create_order`` for the chosen move/switch, or
            ``choose_random_move(battle)`` when no usable action id came back.
        """
        self.current_battle = battle

        battle_state_str = self._format_battle_state(battle)

        initial_graph_state: GraphState = {
            "battle_state_str": battle_state_str,
            "messages": [],
            "generation": None,
            "action_to_take": None,
            "chosen_action_id": None,
            "error_message": None,
            "fallback_triggered": False,
            "agent_name": self.username,
        }

        # The checkpointer requires a thread_id; one thread per battle turn,
        # so each decision starts from the input above (no cross-turn memory).
        config = {
            "configurable": {
                "thread_id": f"battle_{battle.battle_tag}_turn_{battle.turn}"
            }
        }

        try:
            final_state = await self.compiled_graph.ainvoke(initial_graph_state, config=config)
        except Exception as e:
            # A graph failure must not crash the battle; fall through to random.
            print(f"Error invoking LangGraph: {e}")
            final_state = {
                **initial_graph_state,
                "fallback_triggered": True,
                "error_message": f"Graph execution error: {str(e)}",
            }

        chosen_action_id = final_state.get("chosen_action_id") if final_state else None
        if chosen_action_id:
            chosen_move = self._find_move_by_name(battle, chosen_action_id)
            if chosen_move is not None:
                return self.create_order(chosen_move)
            chosen_switch = self._find_pokemon_by_name(battle, chosen_action_id)
            if chosen_switch is not None:
                return self.create_order(chosen_switch)
            print(f"Warning: Graph returned invalid action ID '{chosen_action_id}'. Falling back.")
        else:
            print("Warning: Graph returned no action ID. Falling back.")
        return self.choose_random_move(battle)