Spaces:
Running
Running
Update agents.py
Browse files
agents.py
CHANGED
|
@@ -56,6 +56,44 @@ STANDARD_TOOL_SCHEMA = {
|
|
| 56 |
},
|
| 57 |
}
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
class LLMAgentBase(Player):
|
| 61 |
def __init__(self, *args, **kwargs):
|
|
@@ -202,26 +240,15 @@ class GeminiAgent(LLMAgentBase):
|
|
| 202 |
if not used_api_key:
|
| 203 |
raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
|
| 204 |
|
| 205 |
-
# Initialize Gemini client
|
| 206 |
-
genai.
|
| 207 |
|
| 208 |
-
# Configure the
|
| 209 |
-
self.
|
| 210 |
-
{
|
| 211 |
-
"function_declarations": list(self.standard_tools.values())
|
| 212 |
-
}
|
| 213 |
-
]
|
| 214 |
-
|
| 215 |
-
# Initialize the model
|
| 216 |
-
self.model = genai.GenerativeModel(
|
| 217 |
-
model_name=self.model_name,
|
| 218 |
-
tools=self.gemini_tool_config
|
| 219 |
-
)
|
| 220 |
|
| 221 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
| 222 |
"""Sends state to the Gemini API and gets back the function call decision."""
|
| 223 |
prompt = (
|
| 224 |
-
"You are a skilled Pokemon battle AI. Your goal is to win the battle. "
|
| 225 |
"Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
|
| 226 |
"Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
|
| 227 |
"Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
|
|
@@ -231,49 +258,40 @@ class GeminiAgent(LLMAgentBase):
|
|
| 231 |
)
|
| 232 |
|
| 233 |
try:
|
| 234 |
-
#
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
)
|
| 239 |
-
print("GEMINI RESPONSE : ",response)
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
function_name = fc.name
|
| 261 |
-
# Convert arguments to dict
|
| 262 |
-
arguments = {}
|
| 263 |
-
if fc.args:
|
| 264 |
-
arguments = {k: v for k, v in fc.args.items()}
|
| 265 |
-
|
| 266 |
-
if function_name in self.standard_tools:
|
| 267 |
-
return {"decision": {"name": function_name, "arguments": arguments}}
|
| 268 |
-
else:
|
| 269 |
-
return {"error": f"Model called unknown function '{function_name}'. Args: {arguments}"}
|
| 270 |
|
| 271 |
-
#
|
| 272 |
-
|
| 273 |
-
part.text if hasattr(part, 'text') else str(part)
|
| 274 |
-
for part in candidate.content.parts
|
| 275 |
-
])
|
| 276 |
-
return {"error": f"Gemini did not return a function call. Response: {text_content[:100]}..."}
|
| 277 |
|
| 278 |
except Exception as e:
|
| 279 |
print(f"Unexpected error during Gemini processing: {e}")
|
|
@@ -293,8 +311,8 @@ class OpenAIAgent(LLMAgentBase):
|
|
| 293 |
raise ValueError("OpenAI API key not provided or found in OPENAI_API_KEY env var.")
|
| 294 |
self.openai_client = AsyncOpenAI(api_key=used_api_key)
|
| 295 |
|
| 296 |
-
#
|
| 297 |
-
self.openai_tools = list(
|
| 298 |
|
| 299 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
| 300 |
system_prompt = (
|
|
@@ -354,8 +372,17 @@ class MistralAgent(LLMAgentBase):
|
|
| 354 |
raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
|
| 355 |
self.mistral_client = Mistral(api_key=used_api_key)
|
| 356 |
|
| 357 |
-
# Convert standard schema to Mistral's tool format
|
| 358 |
-
self.mistral_tools =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
| 361 |
system_prompt = (
|
|
@@ -368,23 +395,29 @@ class MistralAgent(LLMAgentBase):
|
|
| 368 |
user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
|
| 369 |
|
| 370 |
try:
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
model=self.model,
|
| 373 |
-
messages=
|
| 374 |
-
{"role": "system", "content": system_prompt},
|
| 375 |
-
{"role": "user", "content": user_prompt}
|
| 376 |
-
],
|
| 377 |
tools=self.mistral_tools,
|
| 378 |
-
tool_choice="
|
| 379 |
-
temperature=0.
|
| 380 |
)
|
| 381 |
-
print("Mistral RESPONSE : ",response)
|
| 382 |
-
|
| 383 |
# Check for tool calls in the response
|
| 384 |
-
|
|
|
|
| 385 |
tool_call = message.tool_calls[0] # Get the first tool call
|
| 386 |
function_name = tool_call.function.name
|
| 387 |
try:
|
|
|
|
| 388 |
arguments = json.loads(tool_call.function.arguments or '{}')
|
| 389 |
if function_name in self.standard_tools:
|
| 390 |
return {"decision": {"name": function_name, "arguments": arguments}}
|
|
@@ -393,9 +426,11 @@ class MistralAgent(LLMAgentBase):
|
|
| 393 |
except json.JSONDecodeError:
|
| 394 |
return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
|
| 395 |
else:
|
| 396 |
-
# Model
|
| 397 |
return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
|
| 398 |
|
| 399 |
except Exception as e:
|
| 400 |
print(f"Error during Mistral API call: {e}")
|
|
|
|
|
|
|
| 401 |
return {"error": f"Unexpected error: {str(e)}"}
|
|
|
|
| 56 |
},
|
| 57 |
}
|
| 58 |
|
| 59 |
+
# --- OpenAI Tools Schema (with 'type' field) ---
|
| 60 |
+
OPENAI_TOOL_SCHEMA = {
|
| 61 |
+
"choose_move": {
|
| 62 |
+
"type": "function",
|
| 63 |
+
"function": {
|
| 64 |
+
"name": "choose_move",
|
| 65 |
+
"description": "Selects and executes an available attacking or status move.",
|
| 66 |
+
"parameters": {
|
| 67 |
+
"type": "object",
|
| 68 |
+
"properties": {
|
| 69 |
+
"move_name": {
|
| 70 |
+
"type": "string",
|
| 71 |
+
"description": "The exact name or ID (e.g., 'thunderbolt', 'swordsdance') of the move to use. Must be one of the available moves.",
|
| 72 |
+
},
|
| 73 |
+
},
|
| 74 |
+
"required": ["move_name"],
|
| 75 |
+
},
|
| 76 |
+
}
|
| 77 |
+
},
|
| 78 |
+
"choose_switch": {
|
| 79 |
+
"type": "function",
|
| 80 |
+
"function": {
|
| 81 |
+
"name": "choose_switch",
|
| 82 |
+
"description": "Selects an available Pokémon from the bench to switch into.",
|
| 83 |
+
"parameters": {
|
| 84 |
+
"type": "object",
|
| 85 |
+
"properties": {
|
| 86 |
+
"pokemon_name": {
|
| 87 |
+
"type": "string",
|
| 88 |
+
"description": "The exact name of the Pokémon species to switch to (e.g., 'Pikachu', 'Charizard'). Must be one of the available switches.",
|
| 89 |
+
},
|
| 90 |
+
},
|
| 91 |
+
"required": ["pokemon_name"],
|
| 92 |
+
},
|
| 93 |
+
}
|
| 94 |
+
},
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
|
| 98 |
class LLMAgentBase(Player):
|
| 99 |
def __init__(self, *args, **kwargs):
|
|
|
|
| 240 |
if not used_api_key:
|
| 241 |
raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
|
| 242 |
|
| 243 |
+
# Initialize Gemini client using the correct API
|
| 244 |
+
self.genai_client = genai.Client(api_key=used_api_key)
|
| 245 |
|
| 246 |
+
# Configure the tools for function calling
|
| 247 |
+
self.function_declarations = list(self.standard_tools.values())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
| 250 |
"""Sends state to the Gemini API and gets back the function call decision."""
|
| 251 |
prompt = (
|
|
|
|
| 252 |
"Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
|
| 253 |
"Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
|
| 254 |
"Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
|
|
|
|
| 258 |
)
|
| 259 |
|
| 260 |
try:
|
| 261 |
+
# Configure tools using the Gemini API format
|
| 262 |
+
tools = genai.types.Tool(function_declarations=self.function_declarations)
|
| 263 |
+
config = genai.types.GenerateContentConfig(tools=[tools])
|
| 264 |
+
|
| 265 |
+
# Send request to the model
|
| 266 |
+
response = self.genai_client.models.generate_content(
|
| 267 |
+
model=self.model_name,
|
| 268 |
+
contents=prompt,
|
| 269 |
+
config=config
|
| 270 |
)
|
| 271 |
+
print("GEMINI RESPONSE : ", response)
|
| 272 |
+
|
| 273 |
+
# Check for function calls in the response
|
| 274 |
+
if (hasattr(response, 'candidates') and
|
| 275 |
+
response.candidates and
|
| 276 |
+
hasattr(response.candidates[0], 'content') and
|
| 277 |
+
hasattr(response.candidates[0].content, 'parts') and
|
| 278 |
+
response.candidates[0].content.parts and
|
| 279 |
+
hasattr(response.candidates[0].content.parts[0], 'function_call')):
|
| 280 |
+
|
| 281 |
+
function_call = response.candidates[0].content.parts[0].function_call
|
| 282 |
+
function_name = function_call.name
|
| 283 |
+
# Get arguments
|
| 284 |
+
arguments = {}
|
| 285 |
+
if hasattr(function_call, 'args'):
|
| 286 |
+
arguments = function_call.args
|
| 287 |
+
|
| 288 |
+
if function_name in self.standard_tools:
|
| 289 |
+
return {"decision": {"name": function_name, "arguments": arguments}}
|
| 290 |
+
else:
|
| 291 |
+
return {"error": f"Model called unknown function '{function_name}'."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
# No function call found
|
| 294 |
+
return {"error": "Gemini did not return a function call."}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
except Exception as e:
|
| 297 |
print(f"Unexpected error during Gemini processing: {e}")
|
|
|
|
| 311 |
raise ValueError("OpenAI API key not provided or found in OPENAI_API_KEY env var.")
|
| 312 |
self.openai_client = AsyncOpenAI(api_key=used_api_key)
|
| 313 |
|
| 314 |
+
# Use the OpenAI-specific schema with type field
|
| 315 |
+
self.openai_tools = list(OPENAI_TOOL_SCHEMA.values())
|
| 316 |
|
| 317 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
| 318 |
system_prompt = (
|
|
|
|
| 372 |
raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
|
| 373 |
self.mistral_client = Mistral(api_key=used_api_key)
|
| 374 |
|
| 375 |
+
# Convert standard schema to Mistral's tool format with "function" wrapper
|
| 376 |
+
self.mistral_tools = []
|
| 377 |
+
for tool_name, tool_schema in self.standard_tools.items():
|
| 378 |
+
self.mistral_tools.append({
|
| 379 |
+
"type": "function",
|
| 380 |
+
"function": {
|
| 381 |
+
"name": tool_schema["name"],
|
| 382 |
+
"description": tool_schema["description"],
|
| 383 |
+
"parameters": tool_schema["parameters"]
|
| 384 |
+
}
|
| 385 |
+
})
|
| 386 |
|
| 387 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
| 388 |
system_prompt = (
|
|
|
|
| 395 |
user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
|
| 396 |
|
| 397 |
try:
|
| 398 |
+
# Create the messages array
|
| 399 |
+
messages = [
|
| 400 |
+
{"role": "system", "content": system_prompt},
|
| 401 |
+
{"role": "user", "content": user_prompt}
|
| 402 |
+
]
|
| 403 |
+
|
| 404 |
+
# Call the Mistral API with tool_choice set to "any" to force tool usage
|
| 405 |
+
response = self.mistral_client.chat.complete(
|
| 406 |
model=self.model,
|
| 407 |
+
messages=messages,
|
|
|
|
|
|
|
|
|
|
| 408 |
tools=self.mistral_tools,
|
| 409 |
+
tool_choice="any", # Force the model to use a tool
|
| 410 |
+
temperature=0.3,
|
| 411 |
)
|
| 412 |
+
print("Mistral RESPONSE : ", response)
|
| 413 |
+
|
| 414 |
# Check for tool calls in the response
|
| 415 |
+
message = response.choices[0].message
|
| 416 |
+
if hasattr(message, 'tool_calls') and message.tool_calls:
|
| 417 |
tool_call = message.tool_calls[0] # Get the first tool call
|
| 418 |
function_name = tool_call.function.name
|
| 419 |
try:
|
| 420 |
+
# Parse the function arguments from JSON string
|
| 421 |
arguments = json.loads(tool_call.function.arguments or '{}')
|
| 422 |
if function_name in self.standard_tools:
|
| 423 |
return {"decision": {"name": function_name, "arguments": arguments}}
|
|
|
|
| 426 |
except json.JSONDecodeError:
|
| 427 |
return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
|
| 428 |
else:
|
| 429 |
+
# Model did not return a tool call
|
| 430 |
return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
|
| 431 |
|
| 432 |
except Exception as e:
|
| 433 |
print(f"Error during Mistral API call: {e}")
|
| 434 |
+
import traceback
|
| 435 |
+
traceback.print_exc()
|
| 436 |
return {"error": f"Unexpected error: {str(e)}"}
|