# likhon saheikh
# Upgrade to Gemini 2.0 Flash Thinking model
# 3cbd06b
import os
import json
import httpx
from typing import AsyncGenerator, Dict, Any
from src.domain.interfaces import AgentRepository
from src.domain.models import Session, Message, MessageRole
from src.infrastructure.tools.registry import tool_registry
class GeminiAgent(AgentRepository):
    """Agent that streams chat responses from the Gemini REST API.

    The conversation is sent to the ``streamGenerateContent`` endpoint with
    ``alt=sse`` and the response stream is translated into event dictionaries
    of the form ``{"event": <name>, "data": <payload>}``.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash-thinking-exp-1219"):
        """Store the credentials and model used for every request.

        Args:
            api_key: Gemini API key sent with each request.
            model_name: Model identifier appended to the base URL.
        """
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = "https://generativelanguage.googleapis.com/v1beta/models"

    async def chat(self, session: Session, message: Message) -> AsyncGenerator[Dict[str, Any], None]:
        """Send ``message`` (preceded by the session history) and stream events.

        Args:
            session: Conversation whose ``messages`` are replayed as history.
            message: The new user message appended after the history.

        Yields:
            Event dictionaries with these ``event`` values:

            * ``"message"``     - ``data`` is a text fragment from the model.
            * ``"tool"``        - ``data`` is ``{"name", "args"}`` of a requested call.
            * ``"tool_result"`` - ``data`` is ``{"name", "result"}`` of an executed tool.
            * ``"error"``       - ``data`` is a human-readable error string.
            * ``"done"``        - ``data`` is ``"stop"``; emitted last, except after
              a non-200 HTTP response, where the stream ends on the error event.
        """
        # Convert session history to Gemini format: every non-user role maps to "model".
        contents = [
            {
                "role": "user" if msg.role == MessageRole.USER else "model",
                "parts": [{"text": msg.content}],
            }
            for msg in session.messages
        ]
        # Add current message
        contents.append({"role": "user", "parts": [{"text": message.content}]})

        # Get tools
        tools = tool_registry.to_gemini_tools()

        # Prepare request. The key travels in a header rather than the query
        # string so it does not end up in URLs that get logged or echoed back
        # in error text.
        url = f"{self.base_url}/{self.model_name}:streamGenerateContent?alt=sse"
        headers = {"x-goog-api-key": self.api_key}
        payload = {
            "contents": contents,
            "tools": tools,
            "generationConfig": {
                "temperature": 0.7,
                "thinking_config": {"include_thoughts": True},
            },
        }

        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    "POST", url, headers=headers, json=payload, timeout=60.0
                ) as response:
                    if response.status_code != 200:
                        # A streamed async response must be read with aread();
                        # the synchronous read() raises RuntimeError here.
                        error_body = await response.aread()
                        error_msg = error_body.decode(errors="replace")
                        if response.status_code == 429:
                            yield {"event": "error", "data": "Gemini API Quota Exceeded. Please check your usage limits or try again later."}
                        else:
                            yield {"event": "error", "data": f"API Error {response.status_code}: {error_msg}"}
                        return

                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        try:
                            chunk = json.loads(line[6:])
                        except json.JSONDecodeError:
                            # Best effort: skip payloads that are not valid JSON.
                            continue

                        # Parse candidates; only the first candidate is used.
                        candidates = chunk.get("candidates") if isinstance(chunk, dict) else None
                        if not candidates:
                            continue
                        content = candidates[0].get("content") or {}
                        for part in content.get("parts") or []:
                            if "text" in part:
                                yield {"event": "message", "data": part["text"]}
                            if "functionCall" not in part:
                                continue

                            fc = part["functionCall"]
                            name = fc["name"]
                            # "args" may be absent for calls without arguments.
                            args = fc.get("args") or {}
                            yield {"event": "tool", "data": {"name": name, "args": args}}

                            # Execute tool
                            tool = tool_registry.get_tool(name)
                            if not tool:
                                yield {"event": "error", "data": f"Tool not found: {name}"}
                                continue
                            try:
                                result = await tool.func(**args)
                                yield {
                                    "event": "tool_result",
                                    "data": {"name": name, "result": result},
                                }
                                # Note: In a real implementation, we'd need to send this result back
                                # to the model in a new turn. For this demo, we just verify execution.
                            except Exception as e:
                                yield {"event": "error", "data": f"Tool execution failed: {e}"}
            except Exception as e:
                # Boundary handler: surface transport/parsing failures as an event
                # instead of raising into the consumer of the stream.
                yield {"event": "error", "data": str(e)}

        yield {"event": "done", "data": "stop"}