Metropolis-Chess-Club / models /gemini_api.py
Forkei's picture
Upload folder using huggingface_hub
52a4f3c verified
"""
Gemini 3.1 Flash Lite API wrapper for the Chess Master agent.
Handles:
- Structured JSON output (for agent responses)
- Tool/function calling (send_message, save_memory, set_emotion, stop)
- Error handling and retries
- Temperature and sampling configuration
- Fallback to Claude API if needed
"""
import json
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Dict, Any, List
from datetime import datetime
import google.generativeai as genai
from google.api_core.exceptions import GoogleAPIError, DeadlineExceeded
from config.settings import GEMINI_API_KEY, GEMINI_MODEL
# Module-scoped logger, named after this module so log handlers can filter on it.
logger = logging.getLogger(name=__name__)
class GeminiClient:
    """
    Gemini 3.1 Flash Lite API client for the Chess Master agent.

    The model is instructed (in the prompt built by `_build_full_prompt`) to
    reply with a single JSON object whose "action" field selects an agent
    action; `_parse_response` decodes and validates that object.

    Supports:
    - JSON response parsing, tolerant of a surrounding ``` code fence or of
      stray prose around the JSON object
    - Automatic retries with exponential backoff on Google API errors
    - Per-request temperature and maximum-output-token configuration
    """

    # Agent actions the model may choose, with their expected arguments.
    # NOTE(review): no method of this class reads TOOLS (no native Gemini
    # function calling is configured); presumably callers use it - verify
    # before removing.
    TOOLS = {
        "send_message": {
            "description": "Send a message to the player",
            "parameters": {
                "content": {"type": "string", "description": "The message to send"},
                "tone": {"type": "string", "description": "Tone: sharp, playful, respectful, dismissive, warm"}
            }
        },
        "save_memory": {
            "description": "Save a memory about the player or the game",
            "parameters": {
                "content": {"type": "string", "description": "What to remember"},
                "memory_type": {
                    "type": "string",
                    "description": "Category: player_behavior, player_observation, game_context, personal_note, pattern, streak, emotional"
                },
                "related_player_id": {"type": "string", "description": "Optional player ID"}
            }
        },
        "set_emotion": {
            "description": "Display an emotion (for future visual representation)",
            "parameters": {
                "emotion": {"type": "string", "description": "One of: smirk, thoughtful, amused, focused, dismissive, respectful"}
            }
        },
        "stop": {
            "description": "Do nothing. Choose silence.",
            "parameters": {}
        }
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        temperature: float = 0.8,
        max_tokens: int = 1024,
        max_retries: int = 3,
    ) -> None:
        """
        Initialize the Gemini client.

        Args:
            api_key: Gemini API key (defaults to GEMINI_API_KEY from config).
            model: Model name (defaults to GEMINI_MODEL from config).
            temperature: Sampling temperature sent with every request
                (higher = more creative).
            max_tokens: Maximum number of output tokens per response.
            max_retries: Total number of attempts per `respond` call when the
                Google API raises an error. Values below 1 still make one
                attempt.

        Raises:
            ValueError: If no API key is passed or configured.
        """
        self.api_key = api_key or GEMINI_API_KEY
        self.model = model or GEMINI_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not set in environment or config")
        # NOTE(review): genai.configure appears to set module-global SDK state,
        # so clients created with different keys may interfere - confirm.
        genai.configure(api_key=self.api_key)
        self.client = genai.GenerativeModel(self.model)
        # generate_content is called synchronously; this pool keeps it off the
        # event loop (see _call_gemini).
        self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="gemini")
        self.call_count = 0
        self.error_count = 0

    async def respond(
        self,
        system_prompt: str,
        user_prompt: str,
        player_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Generate a response from Gemini and parse it as a JSON action object.

        Args:
            system_prompt: System context and personality.
            user_prompt: User input / game context.
            player_id: Optional player ID. Currently not used by this method.

        Returns:
            The parsed JSON object. It always has "action", "content"
            ("content" is None for a bare stop), "tone" (None if absent) and
            "metadata" ({} if absent). Example:
            {
                "thinking": "optional reasoning",
                "action": "send_message",
                "content": "Your move. I'm waiting.",
                "tone": "sharp",
                "metadata": {"memory_to_save": {...}}
            }

        Raises:
            GoogleAPIError: If every attempt fails with an API error (these
                are retried with backoff of 1s, 2s, 4s, ...).
            json.JSONDecodeError: If the reply is not valid JSON (not retried).
            ValueError: If the JSON is not a valid action object (not retried).
        """
        self.call_count += 1
        full_prompt = self._build_full_prompt(system_prompt, user_prompt)
        # max_retries is the total number of attempts; always make at least one.
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                logger.debug("Gemini API call #%s, attempt %s", self.call_count, attempt)
                response = await self._call_gemini(full_prompt)
                parsed = self._parse_response(response)
            except (DeadlineExceeded, GoogleAPIError) as e:
                self.error_count += 1
                logger.warning("Gemini API error (attempt %s/%s): %s", attempt, attempts, e)
                if attempt == attempts:
                    logger.error("Failed after %s attempts", attempts)
                    raise
                # Exponential backoff: 1s, 2s, 4s, ...
                wait_time = 2 ** (attempt - 1)
                logger.debug("Retrying in %ss...", wait_time)
                await asyncio.sleep(wait_time)
            except json.JSONDecodeError as e:
                self.error_count += 1
                logger.error("Failed to parse JSON response: %s", e)
                raise
            except ValueError as e:
                # Schema validation failures (and SDK reply errors) are counted
                # like decode failures instead of escaping the stats.
                self.error_count += 1
                logger.error("Invalid Gemini response: %s", e)
                raise
            else:
                logger.info("Gemini response: action=%s", parsed.get("action"))
                return parsed
        # Unreachable: the last attempt either returns or re-raises.
        raise RuntimeError("All retries exhausted")

    async def _call_gemini(self, prompt: str) -> str:
        """
        Run the synchronous Gemini SDK call in the client's thread pool.

        The configured temperature and max_tokens are sent with each request
        as its generation config.

        Args:
            prompt: The fully assembled prompt text.

        Returns:
            The text of the model's reply.

        NOTE(review): the SDK's `.text` accessor is expected to raise
        ValueError when the reply has no text part (e.g. a blocked response).
        Confirm this against the installed google-generativeai version.
        """
        loop = asyncio.get_running_loop()
        generation_config = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_tokens,
        }
        return await loop.run_in_executor(
            self._executor,
            lambda: self.client.generate_content(
                prompt, generation_config=generation_config
            ).text,
        )

    def _build_full_prompt(self, system_prompt: str, user_prompt: str) -> str:
        """
        Combine the system and user prompts with the JSON-only reply instructions.

        The instruction text is kept flush-left inside the f-string because it
        is sent to the model verbatim.
        """
        return f"""{system_prompt}
---
{user_prompt}
---
IMPORTANT: You must respond ONLY with valid JSON. No markdown, no extra text.
The JSON must have this structure:
{{
"thinking": "optional pre-response reasoning",
"action": "send_message | stop | save_memory | set_emotion",
"content": "the message, memory content, or emotion",
"tone": "optional tone indicator",
"metadata": {{optional additional context}}
}}
Required fields: action, content
Optional fields: thinking, tone, metadata
Examples:
{{"action": "send_message", "content": "Nice move.", "tone": "respectful"}}
{{"action": "save_memory", "content": "Alice always plays the Sicilian", "memory_type": "player_behavior"}}
{{"action": "stop"}}
"""

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return the lines between an opening ``` fence and the next closing fence."""
        body = []
        for line in text.splitlines()[1:]:  # skip the opening fence (e.g. ```json)
            if line.strip().startswith("```"):
                break  # ignore anything after the closing fence
            body.append(line)
        return "\n".join(body)

    @staticmethod
    def _decode_json(text: str) -> Any:
        """
        Decode `text` as JSON. If that fails, fall back to the outermost {...} span.

        Raises:
            json.JSONDecodeError: The error from the first decode attempt, if
                the fallback also fails.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError as original_error:
            # Models sometimes wrap the object in prose despite the instructions.
            start, end = text.find("{"), text.rfind("}")
            if start != -1 and end > start:
                try:
                    return json.loads(text[start:end + 1])
                except json.JSONDecodeError:
                    pass  # report the original error below
            logger.error("Invalid JSON response: %s", text[:200])
            raise original_error

    def _parse_response(self, response_text: str) -> Dict[str, Any]:
        """
        Parse and validate Gemini's JSON reply.

        Handles:
        - Extracting JSON from a leading markdown code block, if present
        - Recovering a JSON object surrounded by extra prose
        - Validating that the reply is an object with a known action
        - Filling defaults: content=None (only possible for stop), tone=None,
          metadata={}

        Raises:
            json.JSONDecodeError: If no JSON can be decoded.
            ValueError: If the JSON is not an object, lacks "action", lacks
                "content" for a non-stop action, or names an unknown action.
        """
        text = response_text.strip()
        if text.startswith("```"):
            text = self._strip_code_fence(text)
        parsed = self._decode_json(text)

        # A list or string would otherwise slip through the `in` checks below.
        if not isinstance(parsed, dict):
            raise ValueError(f"Expected a JSON object in response, got: {parsed!r}")
        if "action" not in parsed:
            raise ValueError(f"Missing 'action' field in response: {parsed}")
        if "content" not in parsed and parsed.get("action") != "stop":
            raise ValueError(f"Missing 'content' field in response: {parsed}")

        valid_actions = ["send_message", "stop", "save_memory", "set_emotion"]
        if parsed["action"] not in valid_actions:
            raise ValueError(f"Invalid action '{parsed['action']}'. Must be one of: {valid_actions}")

        # Uniform shape for callers: optional fields always present.
        parsed.setdefault("content", None)
        parsed.setdefault("tone", None)
        parsed.setdefault("metadata", {})
        return parsed

    def get_stats(self) -> Dict[str, Any]:
        """
        Return API call statistics.

        error_count includes every failed attempt, including retried API
        errors. error_rate is therefore errors per `respond` call and can
        exceed 1.0.
        """
        return {
            "total_calls": self.call_count,
            "total_errors": self.error_count,
            "error_rate": self.error_count / self.call_count if self.call_count > 0 else 0,
        }

    def reset_stats(self) -> None:
        """Reset the call and error counters to zero."""
        self.call_count = 0
        self.error_count = 0

    def close(self) -> None:
        """
        Shut down the worker thread pool used for Gemini calls.

        In-flight calls are not waited for. No further `respond` calls may be
        made on this client afterwards.
        """
        self._executor.shutdown(wait=False)
# Public API of this module: `from <module> import *` exposes only GeminiClient.
__all__ = ["GeminiClient"]