# Ref: https://github.com/zou-group/textgrad/blob/main/textgrad/engine/gemini.py
# Ref: https://ai.google.dev/gemini-api/docs/quickstart?lang=python
# Changed to use the new google-genai package May 25, 2025
# Ref: https://ai.google.dev/gemini-api/docs/migrate
try:
    # Old package (pre-migration): import google.generativeai as genai
    from google import genai
    from google.genai import types
except ImportError:
    raise ImportError(
        "If you'd like to use Gemini models, please install the google-genai "
        "package by running `pip install google-genai`, and add 'GOOGLE_API_KEY' "
        "to your environment variables."
    )

import base64
import io
import json
import os
from typing import List, Union

import platformdirs
from PIL import Image
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from .base import EngineLM, CachedEngine
class ChatGemini(EngineLM, CachedEngine):
    """Gemini chat engine built on the google-genai client, with optional
    on-disk response caching via CachedEngine."""

    SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(
        self,
        model_string="gemini-pro",
        use_cache: bool = False,
        system_prompt=SYSTEM_PROMPT,
        is_multimodal: bool = False,
    ):
        self.use_cache = use_cache
        self.model_string = model_string
        self.system_prompt = system_prompt
        assert isinstance(self.system_prompt, str)
        self.is_multimodal = is_multimodal

        if self.use_cache:
            root = platformdirs.user_cache_dir("agentflow")
            cache_path = os.path.join(root, f"cache_gemini_{model_string}.db")
            super().__init__(cache_path=cache_path)

        if os.getenv("GOOGLE_API_KEY") is None:
            raise ValueError(
                "Please set the GOOGLE_API_KEY environment variable "
                "if you'd like to use Gemini models."
            )
        self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
    def __call__(self, prompt, **kwargs):
        return self.generate(prompt, **kwargs)
    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
        # Dispatch on input shape: a plain string (or a list of strings) is a
        # text-only request; a list containing bytes is a multimodal request.
        try:
            if isinstance(content, str):
                return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)
            elif isinstance(content, list):
                if all(isinstance(item, str) for item in content):
                    full_text = "\n".join(content)
                    return self._generate_from_single_prompt(full_text, system_prompt=system_prompt, **kwargs)
                elif any(isinstance(item, bytes) for item in content):
                    if not self.is_multimodal:
                        raise NotImplementedError(
                            f"Multimodal generation is not enabled for {self.model_string}. "
                            "Initialize the engine with is_multimodal=True and a "
                            "multimodal Gemini model such as 'gemini-1.5-pro'."
                        )
                    return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)
                else:
                    raise ValueError("Unsupported content in list: only str or bytes are allowed.")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}; expected str or list.")
        except Exception as e:
            print(f"Error in generate method: {str(e)}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error details: {e.args}")
            return {
                "error": type(e).__name__,
                "message": str(e),
                "details": getattr(e, "args", None),
            }
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
    def _generate_from_single_prompt(
        self, prompt: str, system_prompt=None, temperature=0.0, max_tokens=4000, top_p=0.99, **kwargs
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        if self.use_cache:
            cache_or_none = self._check_cache(sys_prompt_arg + prompt)
            if cache_or_none is not None:
                return cache_or_none

        # Old google-generativeai style: [{'role': 'user', 'parts': [prompt]}];
        # the new google-genai client accepts a plain list of parts.
        messages = [prompt]
        response = self.client.models.generate_content(
            model=self.model_string,
            contents=messages,
            config=types.GenerateContentConfig(
                system_instruction=sys_prompt_arg,
                max_output_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                candidate_count=1,
            ),
        )
        response_text = response.text

        if self.use_cache:
            self._save_cache(sys_prompt_arg + prompt, response_text)
        return response_text
    def _format_content(self, content: List[Union[str, bytes]]) -> List[Union[str, Image.Image]]:
        """Convert a mixed list of strings and raw image bytes into parts the
        google-genai client accepts: str passes through, bytes become PIL images."""
        formatted_content = []
        for item in content:
            if isinstance(item, bytes):
                image_obj = Image.open(io.BytesIO(item))
                formatted_content.append(image_obj)
            elif isinstance(item, str):
                formatted_content.append(item)
            else:
                raise ValueError(f"Unsupported input type: {type(item)}")
        return formatted_content
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
    def _generate_from_multiple_input(
        self, content: List[Union[str, bytes]], system_prompt=None, temperature=0.0, max_tokens=4000, top_p=0.99, **kwargs
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        formatted_content = self._format_content(content)

        if self.use_cache:
            # Build the cache key from the raw inputs: formatted_content holds
            # PIL Image objects, which json.dumps cannot serialize.
            cache_key = sys_prompt_arg + json.dumps(
                [item if isinstance(item, str) else base64.b64encode(item).decode("utf-8") for item in content]
            )
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        response = self.client.models.generate_content(
            model=self.model_string,
            contents=formatted_content,
            config=types.GenerateContentConfig(
                system_instruction=sys_prompt_arg,
                max_output_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                candidate_count=1,
            ),
        )
        response_text = response.text

        if self.use_cache:
            self._save_cache(cache_key, response_text)
        return response_text
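
# Usage sketch (illustrative, not part of the engine): assumes GOOGLE_API_KEY is
# set and that "gemini-1.5-flash" is a model available to your key; the model
# name and prompts are examples only. The guard keeps imports side-effect free.
if __name__ == "__main__":
    engine = ChatGemini(model_string="gemini-1.5-flash")
    print(engine("Explain nucleus (top_p) sampling in one sentence."))

    # Multimodal calls take a mixed list of str and raw image bytes, e.g.:
    # mm_engine = ChatGemini(model_string="gemini-1.5-flash", is_multimodal=True)
    # with open("example.png", "rb") as f:  # hypothetical local image file
    #     print(mm_engine(["Describe this image.", f.read()]))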