Spaces:
Running
Running
Update agents.py
Browse files
agents.py
CHANGED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import random
|
| 5 |
+
from openai import AsyncOpenAI # Use AsyncOpenAI for async compatibility with poke-env
|
| 6 |
+
|
| 7 |
+
from poke_env.player import Player
|
| 8 |
+
from poke_env.environment.battle import Battle
|
| 9 |
+
from poke_env.environment.move import Move
|
| 10 |
+
from poke_env.environment.pokemon import Pokemon
|
| 11 |
+
from poke_env.player import Player, Observation
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class OpenAIAgent(Player):
|
| 16 |
+
"""
|
| 17 |
+
An AI agent for Pokemon Showdown that uses OpenAI's API
|
| 18 |
+
with function calling to decide its moves.
|
| 19 |
+
"""
|
| 20 |
+
def __init__(self, *args, **kwargs):
|
| 21 |
+
super().__init__(*args, **kwargs)
|
| 22 |
+
|
| 23 |
+
# Initialize OpenAI client
|
| 24 |
+
api_key = os.environ["OPENAI_API_KEY"]
|
| 25 |
+
if not api_key:
|
| 26 |
+
raise ValueError("OPENAI_API_KEY environment variable not set.")
|
| 27 |
+
# Use AsyncOpenAI for compatibility with poke-env's async nature
|
| 28 |
+
self.openai_client = AsyncOpenAI(api_key=api_key)
|
| 29 |
+
self.model = "gpt-4o" # Or "gpt-4-turbo-preview", "gpt-4" etc.
|
| 30 |
+
|
| 31 |
+
# Define the functions OpenAI can "call"
|
| 32 |
+
self.functions = [
|
| 33 |
+
{
|
| 34 |
+
"name": "choose_move",
|
| 35 |
+
"description": "Selects and executes an available attacking or status move.",
|
| 36 |
+
"parameters": {
|
| 37 |
+
"type": "object",
|
| 38 |
+
"properties": {
|
| 39 |
+
"move_name": {
|
| 40 |
+
"type": "string",
|
| 41 |
+
"description": "The exact name of the move to use (e.g., 'Thunderbolt', 'Swords Dance'). Must be one of the available moves.",
|
| 42 |
+
},
|
| 43 |
+
},
|
| 44 |
+
"required": ["move_name"],
|
| 45 |
+
},
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"name": "choose_switch",
|
| 49 |
+
"description": "Selects an available Pokémon from the bench to switch into.",
|
| 50 |
+
"parameters": {
|
| 51 |
+
"type": "object",
|
| 52 |
+
"properties": {
|
| 53 |
+
"pokemon_name": {
|
| 54 |
+
"type": "string",
|
| 55 |
+
"description": "The exact name of the Pokémon species to switch to (e.g., 'Pikachu', 'Charizard'). Must be one of the available switches.",
|
| 56 |
+
},
|
| 57 |
+
},
|
| 58 |
+
"required": ["pokemon_name"],
|
| 59 |
+
},
|
| 60 |
+
},
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
def _format_battle_state(self, battle: Battle) -> str:
|
| 64 |
+
"""Formats the current battle state into a string for the LLM."""
|
| 65 |
+
|
| 66 |
+
# Own active Pokemon details
|
| 67 |
+
active_pkmn = battle.active_pokemon
|
| 68 |
+
active_pkmn_info = f"Your active Pokemon: {active_pkmn.species} " \
|
| 69 |
+
f"(Type: {'/'.join(map(str, active_pkmn.types))}) " \
|
| 70 |
+
f"HP: {active_pkmn.current_hp_fraction * 100:.1f}% " \
|
| 71 |
+
f"Status: {active_pkmn.status.name if active_pkmn.status else 'None'} " \
|
| 72 |
+
f"Boosts: {active_pkmn.boosts}"
|
| 73 |
+
|
| 74 |
+
# Opponent active Pokemon details
|
| 75 |
+
opponent_pkmn = battle.opponent_active_pokemon
|
| 76 |
+
opponent_pkmn_info = f"Opponent's active Pokemon: {opponent_pkmn.species} " \
|
| 77 |
+
f"(Type: {'/'.join(map(str, opponent_pkmn.types))}) " \
|
| 78 |
+
f"HP: {opponent_pkmn.current_hp_fraction * 100:.1f}% " \
|
| 79 |
+
f"Status: {opponent_pkmn.status.name if opponent_pkmn.status else 'None'} " \
|
| 80 |
+
f"Boosts: {opponent_pkmn.boosts}"
|
| 81 |
+
|
| 82 |
+
# Available moves
|
| 83 |
+
available_moves_info = "Available moves:\n"
|
| 84 |
+
if battle.available_moves:
|
| 85 |
+
for move in battle.available_moves:
|
| 86 |
+
available_moves_info += f"- {move.id} (Type: {move.type}, BP: {move.base_power}, Acc: {move.accuracy}, PP: {move.current_pp}/{move.max_pp}, Cat: {move.category.name})\n"
|
| 87 |
+
else:
|
| 88 |
+
available_moves_info += "- None (Must switch or Struggle)\n"
|
| 89 |
+
|
| 90 |
+
# Available switches
|
| 91 |
+
available_switches_info = "Available switches:\n"
|
| 92 |
+
if battle.available_switches:
|
| 93 |
+
for pkmn in battle.available_switches:
|
| 94 |
+
available_switches_info += f"- {pkmn.species} (HP: {pkmn.current_hp_fraction * 100:.1f}%, Status: {pkmn.status.name if pkmn.status else 'None'})\n"
|
| 95 |
+
else:
|
| 96 |
+
available_switches_info += "- None\n"
|
| 97 |
+
|
| 98 |
+
# Combine information
|
| 99 |
+
state_str = f"{active_pkmn_info}\n" \
|
| 100 |
+
f"{opponent_pkmn_info}\n\n" \
|
| 101 |
+
f"{available_moves_info}\n" \
|
| 102 |
+
f"{available_switches_info}\n" \
|
| 103 |
+
f"Weather: {battle.weather}\n" \
|
| 104 |
+
f"Terrains: {battle.fields}\n" \
|
| 105 |
+
f"Your Side Conditions: {battle.side_conditions}\n" \
|
| 106 |
+
f"Opponent Side Conditions: {battle.opponent_side_conditions}\n"
|
| 107 |
+
|
| 108 |
+
return state_str.strip()
|
| 109 |
+
|
| 110 |
+
async def _get_openai_decision(self, battle_state: str) -> dict | None:
|
| 111 |
+
"""Sends state to OpenAI and gets back the function call decision."""
|
| 112 |
+
system_prompt = (
|
| 113 |
+
"You are a skilled Pokemon battle AI. Your goal is to win the battle. "
|
| 114 |
+
"Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
|
| 115 |
+
"Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
|
| 116 |
+
"Only choose actions listed as available."
|
| 117 |
+
)
|
| 118 |
+
user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
response = await self.openai_client.chat.completions.create(
|
| 122 |
+
model=self.model,
|
| 123 |
+
messages=[
|
| 124 |
+
{"role": "system", "content": system_prompt},
|
| 125 |
+
{"role": "user", "content": user_prompt},
|
| 126 |
+
],
|
| 127 |
+
functions=self.functions,
|
| 128 |
+
function_call="auto", # Let the model choose which function to call
|
| 129 |
+
temperature=0.5, # Adjust for creativity vs consistency
|
| 130 |
+
)
|
| 131 |
+
message = response.choices[0].message
|
| 132 |
+
if message.function_call:
|
| 133 |
+
function_name = message.function_call.name
|
| 134 |
+
try:
|
| 135 |
+
arguments = json.loads(message.function_call.arguments)
|
| 136 |
+
return {"name": function_name, "arguments": arguments}
|
| 137 |
+
except json.JSONDecodeError:
|
| 138 |
+
print(f"Error decoding function call arguments: {message.function_call.arguments}")
|
| 139 |
+
return None
|
| 140 |
+
else:
|
| 141 |
+
# Model decided not to call a function (or generated text instead)
|
| 142 |
+
print(f"Warning: OpenAI did not return a function call. Response: {message.content}")
|
| 143 |
+
return None
|
| 144 |
+
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f"Error during OpenAI API call: {e}")
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
def _find_move_by_name(self, battle: Battle, move_name: str) -> Move | None:
|
| 150 |
+
"""Finds the Move object corresponding to the given name."""
|
| 151 |
+
# Normalize name for comparison (lowercase, remove spaces/hyphens)
|
| 152 |
+
normalized_name = move_name.lower().replace(" ", "").replace("-", "")
|
| 153 |
+
for move in battle.available_moves:
|
| 154 |
+
if move.id == normalized_name: # move.id is already normalized
|
| 155 |
+
return move
|
| 156 |
+
# Fallback: try matching against the display name if ID fails (less reliable)
|
| 157 |
+
for move in battle.available_moves:
|
| 158 |
+
if move.id == move_name.lower(): # Handle cases like "U-turn" vs "uturn"
|
| 159 |
+
return move
|
| 160 |
+
if move.name.lower() == move_name.lower():
|
| 161 |
+
return move
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Pokemon | None:
|
| 165 |
+
"""Finds the Pokemon object corresponding to the given species name."""
|
| 166 |
+
# Normalize name for comparison
|
| 167 |
+
normalized_name = pokemon_name.lower()
|
| 168 |
+
for pkmn in battle.available_switches:
|
| 169 |
+
if pkmn.species.lower() == normalized_name:
|
| 170 |
+
return pkmn
|
| 171 |
+
return None
|
| 172 |
+
|
| 173 |
+
async def choose_move(self, battle: Battle) -> str:
|
| 174 |
+
"""
|
| 175 |
+
Main decision-making function called by poke-env each turn.
|
| 176 |
+
"""
|
| 177 |
+
# 1. Format battle state
|
| 178 |
+
battle_state_str = self._format_battle_state(battle)
|
| 179 |
+
# print(f"\n--- Turn {battle.turn} ---")
|
| 180 |
+
# print(battle_state_str) # Optional: print state for debugging
|
| 181 |
+
|
| 182 |
+
# 2. Get decision from OpenAI
|
| 183 |
+
decision = await self._get_openai_decision(battle_state_str)
|
| 184 |
+
|
| 185 |
+
# 3. Parse decision and create order
|
| 186 |
+
if decision:
|
| 187 |
+
function_name = decision["name"]
|
| 188 |
+
args = decision["arguments"]
|
| 189 |
+
# print(f"OpenAI Recommended: {function_name} with args {args}") # Debugging
|
| 190 |
+
|
| 191 |
+
if function_name == "choose_move":
|
| 192 |
+
move_name = args.get("move_name")
|
| 193 |
+
if move_name:
|
| 194 |
+
chosen_move = self._find_move_by_name(battle, move_name)
|
| 195 |
+
if chosen_move and chosen_move in battle.available_moves:
|
| 196 |
+
# print(f"Action: Using move {chosen_move.id}")
|
| 197 |
+
return self.create_order(chosen_move)
|
| 198 |
+
else:
|
| 199 |
+
print(f"Warning: OpenAI chose unavailable/invalid move '{move_name}'. Falling back.")
|
| 200 |
+
else:
|
| 201 |
+
print(f"Warning: OpenAI 'choose_move' called without 'move_name'. Falling back.")
|
| 202 |
+
|
| 203 |
+
elif function_name == "choose_switch":
|
| 204 |
+
pokemon_name = args.get("pokemon_name")
|
| 205 |
+
if pokemon_name:
|
| 206 |
+
chosen_switch = self._find_pokemon_by_name(battle, pokemon_name)
|
| 207 |
+
if chosen_switch and chosen_switch in battle.available_switches:
|
| 208 |
+
# print(f"Action: Switching to {chosen_switch.species}")
|
| 209 |
+
return self.create_order(chosen_switch)
|
| 210 |
+
else:
|
| 211 |
+
print(f"Warning: OpenAI chose unavailable/invalid switch '{pokemon_name}'. Falling back.")
|
| 212 |
+
else:
|
| 213 |
+
print(f"Warning: OpenAI 'choose_switch' called without 'pokemon_name'. Falling back.")
|
| 214 |
+
|
| 215 |
+
# 4. Fallback if API fails, returns invalid action, or no function call
|
| 216 |
+
print("Fallback: Choosing random move/switch.")
|
| 217 |
+
# Ensure options exist before choosing randomly
|
| 218 |
+
available_options = battle.available_moves + battle.available_switches
|
| 219 |
+
if available_options:
|
| 220 |
+
return self.choose_random_move(battle)
|
| 221 |
+
else:
|
| 222 |
+
# Should only happen if forced to Struggle
|
| 223 |
+
return self.choose_default_move(battle)
|