Jofthomas commited on
Commit
503cdbf
·
verified ·
1 Parent(s): 6becbb9

Update agents.py

Browse files
Files changed (1) hide show
  1. agents.py +223 -0
agents.py CHANGED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import random
5
+ from openai import AsyncOpenAI # Use AsyncOpenAI for async compatibility with poke-env
6
+
7
+ from poke_env.player import Player
8
+ from poke_env.environment.battle import Battle
9
+ from poke_env.environment.move import Move
10
+ from poke_env.environment.pokemon import Pokemon
11
+ from poke_env.player import Player, Observation
12
+
13
+
14
+
15
class OpenAIAgent(Player):
    """
    An AI agent for Pokemon Showdown that uses OpenAI's API
    with function (tool) calling to decide its moves.

    Each turn the battle state is rendered as text, sent to the chat model
    together with two callable tools (``choose_move`` / ``choose_switch``),
    and the returned tool call is mapped back onto one of the currently
    available moves or switches. If anything goes wrong (API error, no tool
    call, unknown name) the agent falls back to a random legal action.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the player and the OpenAI client.

        All positional and keyword arguments are forwarded unchanged to
        ``Player.__init__``.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is unset or empty.
        """
        super().__init__(*args, **kwargs)

        # Use .get(): indexing os.environ would raise KeyError before the
        # explicit (and more helpful) ValueError below could be reached.
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set.")

        # Use AsyncOpenAI for compatibility with poke-env's async nature
        self.openai_client = AsyncOpenAI(api_key=api_key)
        self.model = "gpt-4o"  # Or "gpt-4-turbo-preview", "gpt-4" etc.

        # Function specs the model can "call". Kept in the plain
        # name/description/parameters shape; _get_openai_decision wraps them
        # into the `tools` request format.
        self.functions = [
            {
                "name": "choose_move",
                "description": "Selects and executes an available attacking or status move.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "move_name": {
                            "type": "string",
                            "description": "The exact name of the move to use (e.g., 'Thunderbolt', 'Swords Dance'). Must be one of the available moves.",
                        },
                    },
                    "required": ["move_name"],
                },
            },
            {
                "name": "choose_switch",
                "description": "Selects an available Pokémon from the bench to switch into.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "pokemon_name": {
                            "type": "string",
                            "description": "The exact name of the Pokémon species to switch to (e.g., 'Pikachu', 'Charizard'). Must be one of the available switches.",
                        },
                    },
                    "required": ["pokemon_name"],
                },
            },
        ]

    @staticmethod
    def _to_id(name: object) -> str:
        """Reduce a display name to lowercase ASCII letters and digits only.

        For example ``"King's Shield"`` becomes ``"kingsshield"`` and
        ``"Mr. Mime"`` becomes ``"mrmime"``.
        """
        return "".join(
            ch for ch in str(name).lower() if ch.isascii() and ch.isalnum()
        )

    def _format_battle_state(self, battle: Battle) -> str:
        """Formats the current battle state into a string for the LLM.

        Args:
            battle: The battle to describe. Reads the active Pokemon on both
                sides, the available moves and switches, and the weather,
                field and side-condition attributes.

        Returns:
            A multi-line, whitespace-stripped description of the state.
        """

        def label_of(value: object) -> str:
            # Prefer the bare member name when the value exposes one, in line
            # with how status and category are rendered via `.name` below.
            return getattr(value, "name", None) or str(value)

        def describe_active(title: str, pkmn: Pokemon | None) -> str:
            # An active slot may be empty; report that instead of raising
            # AttributeError on None.
            if pkmn is None:
                return f"{title}: None"
            type_names = "/".join(label_of(t) for t in pkmn.types if t is not None)
            status = pkmn.status.name if pkmn.status else "None"
            return (
                f"{title}: {pkmn.species} "
                f"(Type: {type_names}) "
                f"HP: {pkmn.current_hp_fraction * 100:.1f}% "
                f"Status: {status} "
                f"Boosts: {pkmn.boosts}"
            )

        move_lines = [
            f"- {move.id} (Type: {label_of(move.type)}, BP: {move.base_power}, "
            f"Acc: {move.accuracy}, PP: {move.current_pp}/{move.max_pp}, "
            f"Cat: {move.category.name})"
            for move in battle.available_moves
        ] or ["- None (Must switch or Struggle)"]

        switch_lines = [
            f"- {pkmn.species} (HP: {pkmn.current_hp_fraction * 100:.1f}%, "
            f"Status: {pkmn.status.name if pkmn.status else 'None'})"
            for pkmn in battle.available_switches
        ] or ["- None"]

        # Blank entries reproduce the paragraph breaks between sections.
        sections = [
            describe_active("Your active Pokemon", battle.active_pokemon),
            describe_active("Opponent's active Pokemon", battle.opponent_active_pokemon),
            "",
            "Available moves:",
            *move_lines,
            "",
            "Available switches:",
            *switch_lines,
            "",
            f"Weather: {battle.weather}",
            f"Terrains: {battle.fields}",
            f"Your Side Conditions: {battle.side_conditions}",
            f"Opponent Side Conditions: {battle.opponent_side_conditions}",
        ]
        return "\n".join(sections).strip()

    async def _get_openai_decision(self, battle_state: str) -> dict | None:
        """Sends state to OpenAI and gets back the function call decision.

        Args:
            battle_state: Text description of the battle, as produced by
                ``_format_battle_state``.

        Returns:
            ``{"name": <function name>, "arguments": <dict>}`` for the first
            tool call in the reply, or ``None`` when the request fails, the
            model answers without a tool call, or the arguments are not a
            JSON object.
        """
        system_prompt = (
            "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
            "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
            "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
            "Only choose actions listed as available."
        )
        user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."

        # `functions` / `function_call` are deprecated request parameters;
        # wrap the same specs in the `tools` format instead.
        tools = [{"type": "function", "function": spec} for spec in self.functions]

        try:
            response = await self.openai_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                tools=tools,
                tool_choice="auto",  # Let the model choose which function to call
                temperature=0.5,  # Adjust for creativity vs consistency
            )
        except Exception as e:
            # Deliberate best-effort boundary: any API failure must degrade
            # to the caller's fallback rather than abort the battle.
            print(f"Error during OpenAI API call: {e}")
            return None

        if not response.choices:
            print("Warning: OpenAI did not return a function call. Response: None")
            return None

        message = response.choices[0].message
        tool_calls = getattr(message, "tool_calls", None)
        if not tool_calls:
            # Model decided not to call a function (or generated text instead)
            print(f"Warning: OpenAI did not return a function call. Response: {message.content}")
            return None

        # Only one action can be taken per turn, so use the first call.
        call = tool_calls[0].function
        try:
            arguments = json.loads(call.arguments)
        except (json.JSONDecodeError, TypeError):
            print(f"Error decoding function call arguments: {call.arguments}")
            return None

        # Valid JSON that is not an object (null, list, string, ...) would
        # break the `.get()` lookups done by the caller.
        if not isinstance(arguments, dict):
            print(f"Error decoding function call arguments: {call.arguments}")
            return None

        return {"name": call.name, "arguments": arguments}

    def _find_move_by_name(self, battle: Battle, move_name: str) -> Move | None:
        """Finds the Move object corresponding to the given name.

        The name is reduced to lowercase alphanumerics before comparison, so
        display names such as ``"U-turn"`` or ``"King's Shield"`` match the
        move IDs ``uturn`` / ``kingsshield``.

        Returns:
            The matching entry of ``battle.available_moves``, or ``None``.
        """
        wanted = self._to_id(move_name)
        if not wanted:
            return None
        for move in battle.available_moves:
            if self._to_id(move.id) == wanted:
                return move
        return None

    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Pokemon | None:
        """Finds the Pokemon object corresponding to the given species name.

        Both the requested name and each candidate's species are reduced to
        lowercase alphanumerics, so ``"Tapu Koko"`` matches ``tapukoko``.

        Returns:
            The matching entry of ``battle.available_switches``, or ``None``.
        """
        wanted = self._to_id(pokemon_name)
        if not wanted:
            return None
        for pkmn in battle.available_switches:
            if self._to_id(pkmn.species) == wanted:
                return pkmn
        return None

    def _order_from_decision(self, battle: Battle, decision: dict | None):
        """Translate a model decision into an order, or ``None`` to fall back."""
        if not decision:
            return None

        function_name = decision.get("name")
        args = decision.get("arguments")
        if not isinstance(args, dict):
            args = {}

        if function_name == "choose_move":
            move_name = args.get("move_name")
            if not move_name:
                print("Warning: OpenAI 'choose_move' called without 'move_name'. Falling back.")
                return None
            chosen_move = self._find_move_by_name(battle, move_name)
            if chosen_move is None:
                print(f"Warning: OpenAI chose unavailable/invalid move '{move_name}'. Falling back.")
                return None
            return self.create_order(chosen_move)

        if function_name == "choose_switch":
            pokemon_name = args.get("pokemon_name")
            if not pokemon_name:
                print("Warning: OpenAI 'choose_switch' called without 'pokemon_name'. Falling back.")
                return None
            chosen_switch = self._find_pokemon_by_name(battle, pokemon_name)
            if chosen_switch is None:
                print(f"Warning: OpenAI chose unavailable/invalid switch '{pokemon_name}'. Falling back.")
                return None
            return self.create_order(chosen_switch)

        print(f"Warning: OpenAI called unknown function '{function_name}'. Falling back.")
        return None

    async def choose_move(self, battle: Battle) -> str:
        """
        Main decision-making function called by poke-env each turn.

        Formats the battle, asks the model for an action and converts it into
        an order via ``create_order``. Falls back to ``choose_random_move``
        (or ``choose_default_move`` when nothing is available) if the model's
        answer is missing or unusable.

        NOTE(review): the return annotation is ``str``, but the value is
        whatever ``create_order`` / the fallback helpers return - confirm
        against the installed poke-env version.
        """
        # 1. Format battle state
        battle_state_str = self._format_battle_state(battle)

        # 2. Get decision from OpenAI
        decision = await self._get_openai_decision(battle_state_str)

        # 3. Parse decision and create order
        order = self._order_from_decision(battle, decision)
        if order is not None:
            return order

        # 4. Fallback if API fails, returns invalid action, or no function call
        print("Fallback: Choosing random move/switch.")
        if battle.available_moves or battle.available_switches:
            return self.choose_random_move(battle)

        # Should only happen if forced to Struggle.
        # NOTE(review): called without arguments - recent poke-env versions
        # appear to define choose_default_move() with no battle parameter;
        # confirm against the installed version.
        return self.choose_default_move()