from typing import Any, Generator, List, Optional

from injector import inject

from taskweaver.llm.base import CompletionService, EmbeddingService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType, format_chat_message


class GoogleGenAIServiceConfig(LLMServiceConfig):
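    """Configuration for the Google GenAI (Gemini) completion and embedding service."""
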
    def _configure(self) -> None:
        self._set_name("google_genai")

        shared_api_key = self.llm_module_config.api_key
        self.api_key = self._get_str(
            "api_key",
            shared_api_key if shared_api_key is not None else "",
        )
        shared_model = self.llm_module_config.model
        self.model = self._get_str(
            "model",
            shared_model if shared_model is not None else "gemini-pro",
        )
        shared_backup_model = self.llm_module_config.backup_model
        self.backup_model = self._get_str(
            "backup_model",
            shared_backup_model if shared_backup_model is not None else self.model,
        )
        shared_embedding_model = self.llm_module_config.embedding_model
        # A chat model cannot serve embedding requests, so default to Google's
        # dedicated embedding model instead of falling back to `self.model`.
        self.embedding_model = self._get_str(
            "embedding_model",
            shared_embedding_model if shared_embedding_model is not None else "models/embedding-001",
        )

        shared_response_format = self.llm_module_config.response_format
        self.response_format = self._get_enum(
            "response_format",
            options=["json_object", "text"],
            default=shared_response_format if shared_response_format is not None else "text",
        )
        self.temperature = self._get_float("temperature", 0.9)
        self.max_output_tokens = self._get_int("max_output_tokens", 1000)
        self.top_k = self._get_int("top_k", 1)
        self.top_p = self._get_float("top_p", 0)


class GoogleGenAIService(CompletionService, EmbeddingService):
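    """Completion and embedding service backed by the Google Gemini API.

    A minimal usage sketch (illustrative only, assuming a populated
    `GoogleGenAIServiceConfig` named `config`):

        service = GoogleGenAIService(config)
        for chunk in service.chat_completion([format_chat_message("user", "Hello")]):
            print(chunk["content"], end="")
    """
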
    @inject
    def __init__(self, config: GoogleGenAIServiceConfig):
        self.config = config
        genai = self.import_genai_module()
        genai.configure(api_key=self.config.api_key)
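        # Safety thresholds: block content rated medium or above in each harm category.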
        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
        ]

        self.model = genai.GenerativeModel(
            model_name=self.config.model,
            generation_config={
                "temperature": self.config.temperature,
                "top_p": self.config.top_p,
                "top_k": self.config.top_k,
                "max_output_tokens": self.config.max_output_tokens,
            },
            safety_settings=safety_settings,
        )

    def import_genai_module(self):
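        """Import google.generativeai lazily so the dependency stays optional."""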
        try:
            import google.generativeai as genai
        except ImportError:
            raise ImportError(
                "Package google-generativeai is required for using Google Gemini API. "
                "Please install it manually by running: `pip install google-generativeai`",
            )
        return genai

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
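        """Generate chat messages, retrying the request once if the first attempt fails."""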
        try:
            # `_chat_completion` is a generator, so delegate with `yield from`;
            # returning it directly would leave this except block unable to see
            # errors raised while streaming.
            yield from self._chat_completion(
                messages=messages,
                use_backup_engine=use_backup_engine,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stop=stop,
                **kwargs,
            )
        except Exception:
            # Retry once, e.g., to ride out transient API errors.
            yield from self._chat_completion(
                messages=messages,
                use_backup_engine=use_backup_engine,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stop=stop,
                **kwargs,
            )

    def _chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
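        """Map OpenAI-style messages onto Gemini turns and generate content.

        Gemini has no system role and requires user/model turns to alternate,
        so each system message becomes a user turn followed by an acknowledging
        model turn.
        """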
        genai_messages = []
        prev_role = ""
        for msg in messages:
            if msg["role"] == "system":
                genai_messages.append({"role": "user", "parts": [msg["content"]]})
                genai_messages.append(
                    {
                        "role": "model",
                        "parts": ["I understand your requirements, and I will assist you in the conversations."],
                    },
                )
                prev_role = "model"
            elif msg["role"] == "user":
                if prev_role == "user":
                    # Gemini requires alternating roles, so insert a placeholder
                    # model turn between two consecutive user messages.
                    genai_messages.append({"role": "model", "parts": [" "]})
                genai_messages.append({"role": "user", "parts": [msg["content"]]})
                prev_role = "user"
            elif msg["role"] == "assistant":
                genai_messages.append({"role": "model", "parts": [msg["content"]]})
                prev_role = "model"
            else:
                raise ValueError(f"Invalid role: {msg['role']}")

        if stream is False:
            response = self.model.generate_content(genai_messages, stream=False)
            yield format_chat_message("assistant", response.text)
            return

        response = self.model.generate_content(genai_messages, stream=True)
        for chunk_obj in response:
            yield format_chat_message("assistant", chunk_obj.text)

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
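        """Embed each input string with the configured embedding model."""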
        genai = self.import_genai_module()
        embedding_results = genai.embed_content(
            model=self.config.embedding_model,
            content=strings,
            task_type="semantic_similarity",
        )
        return embedding_results["embedding"]