# NOTE: The lines "Spaces:" / "Build error" that previously appeared here were
# page-scrape artifacts from the Hugging Face Spaces build page, not code.
import os
import time
import warnings
from typing import Dict, List, Optional

import httpx

from smolagents import ApiModel, ChatMessage
class GeminiApiModel(ApiModel):
    """ApiModel implementation using the Google Gemini API via direct HTTP requests.

    Messages are flattened into a single text prompt because the basic
    ``generateContent`` endpoint accepts plain text rather than the
    role-structured chat format used by OpenAI-style APIs, and it has no
    native tool-calling or grammar support.
    """

    def __init__(
        self,
        model_id: str = "gemini-pro",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """Initializes the GeminiApiModel.

        Args:
            model_id (str): The Gemini model ID to use (e.g., "gemini-pro").
            api_key (str, optional): Google AI Studio API key. Defaults to the
                GEMINI_API_KEY environment variable.
            **kwargs: Additional keyword arguments passed to the parent ApiModel.
        """
        self.model_id = model_id
        # Prefer an explicitly passed key, fall back to the environment variable.
        self.api_key = api_key if api_key else os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            warnings.warn(
                "GEMINI_API_KEY not provided via argument or environment variable. API calls will likely fail.",
                UserWarning,
            )
        # Gemini's basic API doesn't support complex role structures or
        # function calling like OpenAI, so flatten messages to a single text
        # prompt for simplicity.
        super().__init__(
            model_id=model_id,
            flatten_messages_as_text=True,
            **kwargs,
        )

    def create_client(self):
        """No dedicated client needed; httpx is used directly per request."""
        return None  # Or potentially return an httpx client if reused

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,  # best-effort support only
        grammar: Optional[str] = None,  # not supported by the basic Gemini API
        tools_to_call_from: Optional[List["Tool"]] = None,  # not supported
        **kwargs,
    ) -> ChatMessage:
        """Calls the Google Gemini API with the provided messages.

        Args:
            messages: A list of message dictionaries
                (e.g., [{'role': 'user', 'content': '...'}]).
            stop_sequences: Optional stop sequences, forwarded as
                ``generationConfig.stopSequences`` (best-effort; check the
                Gemini API docs for formal support).
            grammar: Optional grammar constraint (not supported; ignored).
            tools_to_call_from: Optional list of tools (not supported; ignored).
            **kwargs: May include ``temperature`` and ``max_output_tokens``,
                forwarded into the request's generation config.

        Returns:
            A ChatMessage containing the model reply; on failure the message
            content carries an error description and ``raw`` the error details.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY is not set.")

        # The basic generateContent endpoint expects one plain-text prompt.
        prompt = self._messages_to_prompt(messages)
        # BUGFIX: the two instruction sentences were previously concatenated
        # with no separator, producing "...listed.You are a general...".
        prompt += (
            "\n\n"
            + "If you have a result from a web search that looks helpful, please use httpx to get the HTML from the URL listed. "
            + "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
        )

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model_id}:generateContent?key={self.api_key}"
        headers = {"Content-Type": "application/json"}
        # Payload shape required by the generateContent endpoint.
        data = {"contents": [{"parts": [{"text": prompt}]}]}

        # Collect optional generation parameters in one config dict
        # (previously built in two separate places).
        generation_config = {}
        if "temperature" in kwargs:
            generation_config["temperature"] = kwargs["temperature"]
        if "max_output_tokens" in kwargs:
            generation_config["maxOutputTokens"] = kwargs["max_output_tokens"]
        if stop_sequences:
            # Best-effort stop-sequence support via generationConfig.
            generation_config["stopSequences"] = stop_sequences
        if generation_config:
            data["generationConfig"] = generation_config

        raw_response = None
        try:
            # Generous timeout: generation can be slow for long prompts.
            response = httpx.post(url, headers=headers, json=data, timeout=120.0)
            time.sleep(6)  # crude delay to respect rate limits between calls
            response.raise_for_status()
            response_json = response.json()
            raw_response = response_json  # keep for ChatMessage.raw
            if "candidates" in response_json and response_json["candidates"]:
                part = response_json["candidates"][0]["content"]["parts"][0]
                if "text" in part:
                    content = part["text"]
                    # If the model followed the FINAL ANSWER template, keep
                    # only the text after the marker.
                    final_answer_marker = "FINAL ANSWER: "
                    if final_answer_marker in content:
                        content = content.split(final_answer_marker)[-1].strip()
                    # Parse token usage from usageMetadata when the API
                    # returns it; default to 0 otherwise (previously always 0).
                    usage = response_json.get("usageMetadata") or {}
                    self.last_input_token_count = usage.get("promptTokenCount", 0)
                    self.last_output_token_count = usage.get("candidatesTokenCount", 0)
                    return ChatMessage(
                        role="assistant", content=content, raw=raw_response
                    )
            # Expected response structure not found; surface it to the caller.
            error_content = f"Error or unexpected response format: {response_json}"
            return ChatMessage(
                role="assistant", content=error_content, raw=raw_response
            )
        except httpx.RequestError as exc:
            # Network-level failure (DNS, connect, timeout, ...).
            error_content = (
                f"An error occurred while requesting {exc.request.url!r}: {exc}"
            )
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(exc)}
            )
        except httpx.HTTPStatusError as exc:
            # Non-2xx response from the API (raised by raise_for_status).
            error_content = f"Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.text}"
            return ChatMessage(
                role="assistant",
                content=error_content,
                raw={
                    "error": str(exc),
                    "status_code": exc.response.status_code,
                    "response_text": exc.response.text,
                },
            )
        except Exception as e:
            # Last-resort guard so a malformed response never crashes the agent.
            error_content = f"An unexpected error occurred: {e}"
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(e)}
            )

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Converts a list of messages into a single string prompt."""
        # Simple concatenation; could be role-aware if needed. str() guards
        # against non-string content values.
        return "\n".join([str(msg.get("content", "")) for msg in messages])