Spaces:
Runtime error
Runtime error
| import os | |
| from enum import Enum | |
| from typing import Any, Optional, Union | |
| import instructor | |
| import weave | |
| from PIL import Image | |
| from ..utils import base64_encode_image | |
class ClientType(str, Enum):
    """Identifiers for the LLM backends that ``LLMClient`` can dispatch to.

    The ``str`` mix-in is listed before ``Enum`` on purpose: the enum
    machinery requires data-type mix-ins to precede ``Enum`` in the bases
    list, otherwise class creation raises ``TypeError``. Because members are
    also ``str`` instances, they compare equal to their plain string values.
    """

    GEMINI = "gemini"
    MISTRAL = "mistral"
def _as_prompt_list(prompt: Any) -> list[Any]:
    """Normalize a prompt argument to a list of prompt parts.

    ``None`` becomes an empty list, a list/tuple is shallow-copied into a
    list, and any other single value (e.g. one string) is wrapped in a
    one-element list.
    """
    if prompt is None:
        return []
    if isinstance(prompt, (list, tuple)):
        return list(prompt)
    return [prompt]


class LLMClient(weave.Model):
    """Dispatch prompts to the Gemini or Mistral SDK based on ``client_type``."""

    # Model identifier handed to the selected SDK.
    model_name: str
    # Backend selector consulted by ``predict``.
    client_type: ClientType

    def __init__(self, model_name: str, client_type: ClientType):
        """Store the model name and backend type via the base-class initializer."""
        super().__init__(model_name=model_name, client_type=client_type)

    def execute_gemini_sdk(
        self,
        user_prompt: Union[str, list[str]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Send the prompts to Gemini through ``google.generativeai``.

        Args:
            user_prompt: A single prompt or a list of prompt parts.
            system_prompt: Optional prompt part(s) placed before the user
                prompt; ``None`` contributes nothing.
            schema: When given, JSON output is requested with
                ``list[schema]`` as the response schema.

        Returns:
            ``response.text`` when ``schema`` is ``None``; otherwise the raw
            response object returned by ``generate_content``.
        """
        import google.generativeai as genai

        # Normalizing first avoids ``None + list`` when no system prompt is
        # supplied (the default).
        prompt_parts = _as_prompt_list(system_prompt) + _as_prompt_list(user_prompt)

        genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
        model = genai.GenerativeModel(self.model_name)

        generation_config = None
        if schema is not None:
            generation_config = genai.GenerationConfig(
                response_mime_type="application/json", response_schema=list[schema]
            )

        response = model.generate_content(
            prompt_parts, generation_config=generation_config
        )
        return response.text if schema is None else response

    def execute_mistral_sdk(
        self,
        user_prompt: Union[str, list[str]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Send the prompts to Mistral, optionally requesting structured output.

        Args:
            user_prompt: A single prompt or a list of prompt parts. Parts
                that are ``PIL.Image.Image`` instances are sent as
                ``image_url`` chunks; everything else is sent as text.
            system_prompt: Optional system prompt part(s); when absent no
                system message is sent.
            schema: When given, the request goes through ``instructor`` with
                ``schema`` as the ``response_model``.

        Returns:
            The first choice's message content when ``schema`` is ``None``;
            otherwise whatever the instructor-wrapped client returns for
            ``response_model=schema``.
        """
        from mistralai import Mistral

        system_messages = [
            {"type": "text", "text": part} for part in _as_prompt_list(system_prompt)
        ]

        user_messages = []
        for part in _as_prompt_list(user_prompt):
            if isinstance(part, Image.Image):
                user_messages.append(
                    {
                        "type": "image_url",
                        "image_url": base64_encode_image(part, "image/png"),
                    }
                )
            else:
                user_messages.append({"type": "text", "text": part})

        messages = []
        # Skip the system message entirely when there is nothing to put in it.
        if system_messages:
            messages.append({"role": "system", "content": system_messages})
        messages.append({"role": "user", "content": user_messages})

        client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))

        if schema is None:
            # Plain completion: use the raw SDK client, not the instructor
            # wrapper, so ``chat.complete`` resolves on the Mistral client.
            response = client.chat.complete(model=self.model_name, messages=messages)
            return response.choices[0].message.content

        # Structured output: only this path needs the instructor wrapper.
        # NOTE(review): instructor is expected to return the parsed
        # ``response_model`` instance here, which has no ``.choices`` —
        # confirm against the installed instructor version.
        structured_client = instructor.from_mistral(client)
        return structured_client.messages.create(
            model=self.model_name,
            response_model=schema,
            messages=messages,
            temperature=0,
        )

    def predict(
        self,
        user_prompt: Union[str, list[str]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Run the prompts against the backend selected by ``client_type``.

        Args:
            user_prompt: A single prompt or a list of prompt parts.
            system_prompt: Optional system prompt part(s).
            schema: Optional schema forwarded to the backend method.

        Returns:
            The result of ``execute_gemini_sdk`` or ``execute_mistral_sdk``.

        Raises:
            ValueError: If ``client_type`` is neither ``ClientType.GEMINI``
                nor ``ClientType.MISTRAL``.
        """
        if self.client_type == ClientType.GEMINI:
            return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
        elif self.client_type == ClientType.MISTRAL:
            return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
        else:
            raise ValueError(f"Invalid client type: {self.client_type}")