"""Google AI provider implementation (Gemini models)."""

import json
import logging
import time
from typing import Any, AsyncIterator, NoReturn

import httpx

from app.models.providers.base import (
    AuthenticationError,
    BaseProvider,
    CompletionResponse,
    ModelInfo,
    ModelNotFoundError,
    ProviderError,
    RateLimitError,
    TokenUsage,
)

logger = logging.getLogger(__name__)


class GoogleProvider(BaseProvider):
    """Google AI API provider supporting Gemini models.

    Accepts OpenAI-style chat messages and tool definitions, converts them to
    the Gemini ``generateContent`` request format, and converts the response
    back into a ``CompletionResponse``.

    Rate limiting, retry/backoff, cost calculation, usage tracking and
    ``get_model_info`` are not defined in this class; they are expected to be
    provided by ``BaseProvider``.
    """

    PROVIDER_NAME = "google"
    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"

    # Model definitions with pricing (per 1K tokens)
    MODELS = {
        # Gemini 2.5 Series
        "gemini-2.5-pro": ModelInfo(
            id="gemini-2.5-pro",
            name="Gemini 2.5 Pro",
            provider="google",
            context_window=2097152,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.00125,
            cost_per_1k_output=0.005,
        ),
        "gemini-2.5-flash": ModelInfo(
            id="gemini-2.5-flash",
            name="Gemini 2.5 Flash",
            provider="google",
            context_window=1048576,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.000075,
            cost_per_1k_output=0.0003,
        ),
        # Gemini 2.0 Series
        "gemini-2.0-flash": ModelInfo(
            id="gemini-2.0-flash",
            name="Gemini 2.0 Flash",
            provider="google",
            context_window=1048576,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        "gemini-2.0-flash-lite": ModelInfo(
            id="gemini-2.0-flash-lite",
            name="Gemini 2.0 Flash Lite",
            provider="google",
            context_window=524288,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        # Gemini 3.0 Series (Preview)
        "gemini-3-flash-preview": ModelInfo(
            id="gemini-3-flash-preview",
            name="Gemini 3 Flash Preview",
            provider="google",
            context_window=1048576,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        "gemini-3.1-flash-lite-preview": ModelInfo(
            id="gemini-3.1-flash-lite-preview",
            name="Gemini 3.1 Flash Lite Preview",
            provider="google",
            context_window=524288,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        # Gemini 1.5 Series (Stable)
        "gemini-1.5-pro": ModelInfo(
            id="gemini-1.5-pro",
            name="Gemini 1.5 Pro",
            provider="google",
            context_window=2097152,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.00125,
            cost_per_1k_output=0.005,
        ),
        "gemini-1.5-flash": ModelInfo(
            id="gemini-1.5-flash",
            name="Gemini 1.5 Flash",
            provider="google",
            context_window=1048576,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=True,
            supports_streaming=True,
            cost_per_1k_input=0.000075,
            cost_per_1k_output=0.0003,
        ),
        "gemini-pro": ModelInfo(
            id="gemini-pro",
            name="Gemini Pro",
            provider="google",
            context_window=32760,
            max_output_tokens=8192,
            supports_functions=True,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0005,
            cost_per_1k_output=0.0015,
        ),
    }

    # Aliases
    MODEL_ALIASES = {
        "gemini-flash": "gemini-2.5-flash",
        "gemini-pro-latest": "gemini-2.5-pro",
        "gemini-1.5": "gemini-1.5-pro",
    }

    # Gemini finishReason -> OpenAI-style finish_reason. Values not listed
    # here are passed through unchanged.
    _FINISH_REASONS = {
        "STOP": "stop",
        "MAX_TOKENS": "length",
        "SAFETY": "content_filter",
        "RECITATION": "content_filter",
    }

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
        rate_limit_rpm: int = 60,
    ):
        """Store connection settings; the HTTP client is created lazily.

        Args:
            api_key: Google AI API key.
            base_url: API root; falls back to ``DEFAULT_BASE_URL`` when falsy.
            timeout: Passed to the HTTP client as its timeout.
            max_retries: Forwarded to ``BaseProvider``.
            rate_limit_rpm: Forwarded to ``BaseProvider``.
        """
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            timeout=timeout,
            max_retries=max_retries,
            rate_limit_rpm=rate_limit_rpm,
        )
        self._client: httpx.AsyncClient | None = None

    async def initialize(self) -> None:
        """Create the HTTP client, closing any client created earlier."""
        if self._client is not None:
            # Re-initialising must not orphan the previous connection pool.
            await self._client.aclose()
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={"Content-Type": "application/json"},
            timeout=self.timeout,
        )

    async def shutdown(self) -> None:
        """Close the HTTP client, if one exists."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _ensure_client(self) -> httpx.AsyncClient:
        """Return the HTTP client, creating it on first use.

        Raises:
            ProviderError: If ``initialize`` did not produce a client.
        """
        if self._client is None:
            await self.initialize()
        client = self._client
        if client is None:
            raise ProviderError(
                "HTTP client failed to initialize", self.PROVIDER_NAME
            )
        return client

    def _auth_headers(self) -> dict[str, str]:
        """Return the per-request authentication header.

        The key is sent as a header rather than a ``?key=`` query parameter
        so that it does not end up in request URLs, which httpx embeds in
        exception messages.
        """
        return {"x-goog-api-key": self.api_key}

    def _resolve_model(self, model: str) -> str:
        """Resolve a model alias to its full model ID (identity if no alias)."""
        return self.MODEL_ALIASES.get(model, model)

    def get_models(self) -> list[ModelInfo]:
        """Return the models declared in ``MODELS``."""
        return list(self.MODELS.values())

    @staticmethod
    def _content_to_parts(content: Any) -> list[Any]:
        """Turn a message ``content`` value into a list of Gemini parts.

        Strings become a single text part, ``None`` becomes no parts, a
        mapping is treated as one ready-made part, and any other iterable is
        copied as a list of ready-made parts.
        """
        if content is None:
            return []
        if isinstance(content, str):
            return [{"text": content}]
        if isinstance(content, dict):
            return [content]
        # Copy so that appending function-call parts never mutates the
        # caller's message.
        return list(content)

    @staticmethod
    def _parse_arguments(arguments: Any) -> dict[str, Any]:
        """Decode OpenAI-style call arguments into a dict.

        OpenAI-style tool calls carry arguments as a JSON string; a dict is
        accepted as-is. Anything that does not decode to a JSON object
        yields an empty dict (best effort).
        """
        if isinstance(arguments, dict):
            return arguments
        if not arguments:
            return {}
        try:
            decoded = json.loads(arguments)
        except (TypeError, ValueError):
            return {}
        return decoded if isinstance(decoded, dict) else {}

    @staticmethod
    def _function_declaration(func: dict[str, Any]) -> dict[str, Any]:
        """Build one Gemini function declaration from an OpenAI function spec."""
        return {
            "name": func["name"],
            "description": func.get("description", ""),
            "parameters": func.get(
                "parameters", {"type": "object", "properties": {}}
            ),
        }

    def _convert_messages(
        self, messages: list[dict[str, Any]]
    ) -> tuple[str | None, list[dict[str, Any]]]:
        """Convert OpenAI-style messages to Gemini format.

        - ``system`` messages are collected into the system instruction;
          multiple string system messages are joined with a blank line.
        - ``assistant`` messages become ``model`` turns; their ``tool_calls``
          (and legacy ``function_call``) become ``functionCall`` parts.
        - ``function`` / ``tool`` messages become ``functionResponse`` parts.
        - Turns that end up with no parts are skipped, and messages with any
          other role are ignored.

        Returns:
            Tuple of (system_instruction, contents)
        """
        system_instruction: Any = None
        contents: list[dict[str, Any]] = []
        # tool_call id -> function name, learned from assistant turns so that
        # later "tool" messages can be answered under the right name.
        call_names: dict[str, str] = {}

        for msg in messages:
            role = msg["role"]
            # Assistant tool-call messages may omit content or set it to None.
            content = msg.get("content")

            if role == "system":
                if not content:
                    continue
                if isinstance(content, str) and isinstance(system_instruction, str):
                    system_instruction = f"{system_instruction}\n\n{content}"
                else:
                    system_instruction = content

            elif role == "assistant":
                call_parts: list[dict[str, Any]] = []
                for call in msg.get("tool_calls") or []:
                    func = call.get("function") or {}
                    name = func.get("name")
                    if not name:
                        continue
                    if call.get("id"):
                        call_names[call["id"]] = name
                    call_parts.append({
                        "functionCall": {
                            "name": name,
                            "args": self._parse_arguments(func.get("arguments")),
                        }
                    })
                legacy_call = msg.get("function_call")
                if isinstance(legacy_call, dict) and legacy_call.get("name"):
                    call_parts.append({
                        "functionCall": {
                            "name": legacy_call["name"],
                            "args": self._parse_arguments(
                                legacy_call.get("arguments")
                            ),
                        }
                    })

                # An empty text alongside function calls carries nothing.
                if call_parts and content == "":
                    parts = call_parts
                else:
                    parts = self._content_to_parts(content) + call_parts
                if parts:
                    contents.append({"role": "model", "parts": parts})

            elif role == "user":
                parts = self._content_to_parts(content)
                if parts:
                    contents.append({"role": "user", "parts": parts})

            elif role == "function":
                # Function response
                contents.append({
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": msg.get("name", "function"),
                            "response": {"result": content},
                        }
                    }],
                })

            elif role == "tool":
                # Tool response. The functionResponse name must be the
                # function's name, not the call id: prefer the name recorded
                # from the matching assistant call, then an explicit "name",
                # then the id with the "call_" prefix that complete() adds
                # when it synthesises ids.
                call_id = msg.get("tool_call_id")
                name = call_names.get(call_id) if call_id else None
                if not name:
                    name = msg.get("name")
                if not name and isinstance(call_id, str):
                    name = call_id.removeprefix("call_")
                contents.append({
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": name or "tool",
                            "response": {"result": content},
                        }
                    }],
                })

        return system_instruction, contents

    def _convert_tools(
        self, tools: list[dict[str, Any]] | None
    ) -> list[dict[str, Any]] | None:
        """Convert OpenAI-style tools to Gemini format.

        Only entries whose ``type`` is ``"function"`` are converted.

        Returns:
            A one-element list wrapping all function declarations, or ``None``
            when there is nothing to declare.
        """
        if not tools:
            return None

        declarations = [
            self._function_declaration(tool["function"])
            for tool in tools
            if tool.get("type") == "function"
        ]
        return [{"functionDeclarations": declarations}] if declarations else None

    def _build_payload(
        self,
        messages: list[dict[str, Any]],
        temperature: float,
        max_tokens: int | None,
        stop: list[str] | None,
    ) -> dict[str, Any]:
        """Build the request body shared by ``complete`` and ``stream``."""
        system_instruction, contents = self._convert_messages(messages)

        generation_config: dict[str, Any] = {"temperature": temperature}
        if max_tokens:
            generation_config["maxOutputTokens"] = max_tokens
        if stop:
            generation_config["stopSequences"] = stop

        payload: dict[str, Any] = {
            "contents": contents,
            "generationConfig": generation_config,
        }
        if system_instruction:
            payload["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        return payload

    async def complete(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        functions: list[dict[str, Any]] | None = None,
        function_call: str | dict[str, str] | None = None,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> CompletionResponse:
        """Generate a completion using Google AI API.

        Args:
            messages: OpenAI-style chat messages.
            model: Model ID or alias.
            temperature: Sampling temperature.
            max_tokens: Output token cap; omitted from the request when falsy.
            functions: Legacy OpenAI function specs, used only when ``tools``
                yields no declarations.
            function_call: Accepted for interface compatibility; not sent.
            tools: OpenAI-style tool definitions.
            tool_choice: Accepted for interface compatibility; not sent.
            stop: Stop sequences.
            **kwargs: Ignored.

        Returns:
            The first candidate's text and function calls, with token usage,
            cost and latency attached.

        Raises:
            ModelNotFoundError: Unknown model, or HTTP 404.
            AuthenticationError: HTTP 401/403.
            RateLimitError: HTTP 429.
            ProviderError: Other HTTP errors, or a response with no candidates.
        """
        logger.info("GoogleProvider.complete called with model=%s", model)
        await self._acquire_rate_limit()

        model = self._resolve_model(model)
        logger.info("GoogleProvider after resolve: model=%s", model)
        model_info = self.get_model_info(model)
        logger.info("GoogleProvider model_info: %s", model_info)
        if not model_info:
            raise ModelNotFoundError(self.PROVIDER_NAME, model)

        client = await self._ensure_client()
        payload = self._build_payload(messages, temperature, max_tokens, stop)

        # Convert tools, falling back to the legacy "functions" argument.
        gemini_tools = self._convert_tools(tools)
        if not gemini_tools and functions:
            gemini_tools = [{
                "functionDeclarations": [
                    self._function_declaration(func) for func in functions
                ]
            }]
        if gemini_tools:
            payload["tools"] = gemini_tools

        url = f"/models/{model}:generateContent"
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        started = time.perf_counter()
        try:
            response = await self._retry_with_backoff(
                self._make_request, client, url, payload
            )
        except httpx.HTTPStatusError as e:
            self._handle_http_error(e)
        latency_ms = (time.perf_counter() - started) * 1000

        # Parse response
        candidates = response.get("candidates", [])
        if not candidates:
            block_reason = (response.get("promptFeedback") or {}).get("blockReason")
            detail = f" (prompt blocked: {block_reason})" if block_reason else ""
            raise ProviderError(
                f"No candidates in response{detail}", self.PROVIDER_NAME
            )

        candidate = candidates[0]
        content_parts = candidate.get("content", {}).get("parts", [])

        # Extract text content and function calls
        text_chunks: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        for part in content_parts:
            if "text" in part:
                text_chunks.append(part["text"])
            elif "functionCall" in part:
                fc = part["functionCall"]
                tool_calls.append({
                    "id": f"call_{fc['name']}",
                    "type": "function",
                    "function": {
                        "name": fc["name"],
                        "arguments": json.dumps(fc.get("args", {})),
                    },
                })

        # Parse usage
        usage_data = response.get("usageMetadata", {})
        usage = TokenUsage(
            prompt_tokens=usage_data.get("promptTokenCount", 0),
            completion_tokens=usage_data.get("candidatesTokenCount", 0),
            total_tokens=usage_data.get("totalTokenCount", 0),
        )
        cost = self.calculate_cost(model, usage)
        self._track_usage(usage, cost)

        # Map finish reason; unknown values are passed through unchanged.
        raw_finish = candidate.get("finishReason")
        finish_reason = self._FINISH_REASONS.get(raw_finish or "", raw_finish)

        return CompletionResponse(
            content="".join(text_chunks),
            model=model,
            provider=self.PROVIDER_NAME,
            usage=usage,
            finish_reason=finish_reason,
            function_call=None,
            tool_calls=tool_calls if tool_calls else None,
            raw_response=response,
            latency_ms=latency_ms,
            cost=cost,
        )

    async def _make_request(
        self, client: httpx.AsyncClient, url: str, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """POST ``payload`` to ``url`` and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        response = await client.post(
            url, json=payload, headers=self._auth_headers()
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _parse_retry_after(value: str | None) -> float | None:
        """Return a ``Retry-After`` header as seconds, or ``None``.

        The header may also be an HTTP-date; that form is not interpreted
        and yields ``None`` instead of raising.
        """
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            return None

    def _handle_http_error(self, error: httpx.HTTPStatusError) -> NoReturn:
        """Translate an HTTP error from Google AI into a provider exception.

        Always raises; never returns.

        Raises:
            AuthenticationError: Status 401 or 403.
            RateLimitError: Status 429.
            ModelNotFoundError: Status 404.
            ProviderError: Any other status.
        """
        status = error.response.status_code
        message = str(error)
        try:
            body = error.response.json()
            detail = body.get("error", {}).get("message")
            if detail:
                message = detail
        except Exception:
            # Best effort: the body may be unread, non-JSON or oddly shaped.
            # The generic httpx message set above is used instead.
            pass

        if status in (401, 403):
            raise AuthenticationError(self.PROVIDER_NAME, message)
        if status == 429:
            raise RateLimitError(
                self.PROVIDER_NAME,
                retry_after=self._parse_retry_after(
                    error.response.headers.get("retry-after")
                ),
                message=message,
            )
        if status == 404:
            # Extract model name from URL if possible
            model_name = "unknown"
            _, found, tail = str(error.request.url).partition("/models/")
            if found:
                model_name = tail.split(":")[0]
            raise ModelNotFoundError(self.PROVIDER_NAME, model_name)
        raise ProviderError(message, self.PROVIDER_NAME, status)

    async def stream(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream a completion from Google AI.

        Yields the text of each text part as server-sent events arrive.
        Function-call parts are not yielded.

        Args:
            messages: OpenAI-style chat messages.
            model: Model ID or alias.
            temperature: Sampling temperature.
            max_tokens: Output token cap; omitted from the request when falsy.
            **kwargs: ``stop`` (list of stop sequences) is honoured; other
                keys are ignored.

        Raises:
            ModelNotFoundError: Unknown model, or HTTP 404.
            AuthenticationError: HTTP 401/403.
            RateLimitError: HTTP 429.
            ProviderError: Other HTTP errors.
        """
        await self._acquire_rate_limit()

        model = self._resolve_model(model)
        model_info = self.get_model_info(model)
        if not model_info:
            raise ModelNotFoundError(self.PROVIDER_NAME, model)

        client = await self._ensure_client()
        payload = self._build_payload(
            messages, temperature, max_tokens, kwargs.get("stop")
        )

        url = f"/models/{model}:streamGenerateContent?alt=sse"
        try:
            async with client.stream(
                "POST", url, json=payload, headers=self._auth_headers()
            ) as response:
                if response.is_error:
                    # A streamed body is not loaded automatically; read it so
                    # the error handler can extract the API's message.
                    await response.aread()
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    try:
                        chunk = json.loads(line.removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Skip events whose data is not valid JSON.
                        continue
                    candidates = chunk.get("candidates", [])
                    if not candidates:
                        continue
                    parts = candidates[0].get("content", {}).get("parts", [])
                    for part in parts:
                        if "text" in part:
                            yield part["text"]
        except httpx.HTTPStatusError as e:
            self._handle_http_error(e)