import json import re from typing import List, Dict, Any, Union, Optional import io import os import base64 from PIL import Image import mimetypes import litellm from litellm import completion, completion_cost from dotenv import load_dotenv import random load_dotenv() class LiteLLMWrapper: """Wrapper for LiteLLM to support multiple models and logging""" def __init__( self, model_name: str = "gpt-4-vision-preview", temperature: float = 0.7, print_cost: bool = False, verbose: bool = False, use_langfuse: bool = True, ): """ Initialize the LiteLLM wrapper Args: model_name: Name of the model to use (e.g. "azure/gpt-4", "vertex_ai/gemini-pro") temperature: Temperature for completion print_cost: Whether to print the cost of the completion verbose: Whether to print verbose output use_langfuse: Whether to enable Langfuse logging """ self.model_name = model_name self.temperature = temperature self.print_cost = print_cost self.verbose = verbose self.accumulated_cost = 0 # Handle Gemini API key fallback mechanism if "gemini" in model_name.lower(): self._setup_gemini_api_key() if self.verbose: os.environ['LITELLM_LOG'] = 'DEBUG' # Set langfuse callback only if enabled if use_langfuse: litellm.success_callback = ["langfuse"] litellm.failure_callback = ["langfuse"] def _setup_gemini_api_key(self): """Setup Gemini API key with fallback mechanism for multiple keys.""" from dotenv import load_dotenv load_dotenv(override=True) gemini_key_env = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") if not gemini_key_env: raise ValueError("No API_KEY found. Please set the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable.") # Support comma-separated list of API keys with random selection if ',' in gemini_key_env: keys = [key.strip() for key in gemini_key_env.split(',') if key.strip()] if not keys: raise ValueError("No valid API keys found in GEMINI_API_KEY list.") api_key = random.choice(keys) print(f"Selected random Gemini API key from {len(keys)} available keys: {api_key[:20]}...") else: api_key = gemini_key_env print(f"Using single Gemini API key: {api_key[:20]}...") # Set the selected API key for LiteLLM os.environ["GEMINI_API_KEY"] = api_key os.environ["GOOGLE_API_KEY"] = api_key def _encode_file(self, file_path: Union[str, Image.Image]) -> str: """ Encode local file or PIL Image to base64 string Args: file_path: Path to local file or PIL Image object Returns: Base64 encoded file string """ if isinstance(file_path, Image.Image): buffered = io.BytesIO() file_path.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode("utf-8") else: with open(file_path, "rb") as file: return base64.b64encode(file.read()).decode("utf-8") def _get_mime_type(self, file_path: str) -> str: """ Get the MIME type of a file based on its extension Args: file_path: Path to the file Returns: MIME type as a string (e.g., "image/jpeg", "audio/mp3") """ mime_type, _ = mimetypes.guess_type(file_path) if mime_type is None: raise ValueError(f"Unsupported file type: {file_path}") return mime_type def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str: """ Process messages and return completion Args: messages: List of message dictionaries with 'type' and 'content' keys metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking Returns: Generated text response """ if metadata is None: print("No metadata provided, using empty metadata") metadata = {} metadata["trace_name"] = f"litellm-completion-{self.model_name}" # Convert messages to LiteLLM format formatted_messages = [] for msg in messages: if msg["type"] == "text": formatted_messages.append({ "role": "user", "content": [{"type": "text", "text": msg["content"]}] }) elif msg["type"] in ["image", "audio", "video"]: # Check if content is a local file path or PIL Image if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]): try: if isinstance(msg["content"], Image.Image): mime_type = "image/png" else: mime_type = self._get_mime_type(msg["content"]) base64_data = self._encode_file(msg["content"]) data_url = f"data:{mime_type};base64,{base64_data}" except ValueError as e: print(f"Error processing file {msg['content']}: {e}") continue else: data_url = msg["content"] # Append the formatted message based on the model if "gemini" in self.model_name: formatted_messages.append({ "role": "user", "content": [ { "type": "image_url", "image_url": data_url } ] }) elif "gpt" in self.model_name: # GPT and other models expect a different format if msg["type"] == "image": # Default format for images and videos in GPT formatted_messages.append({ "role": "user", "content": [ { "type": f"image_url", f"{msg['type']}_url": { "url": data_url, "detail": "high" } } ] }) else: raise ValueError("For GPT, only text and image inferencing are supported") else: raise ValueError("Only support Gemini and Gpt for Multimodal capability now") try: # if it's openai o series model, set temperature to None and reasoning_effort to "medium" if (re.match(r"^o\d+.*$", self.model_name) or re.match(r"^openai/o.*$", self.model_name)): self.temperature = None self.reasoning_effort = "medium" response = completion( model=self.model_name, messages=formatted_messages, temperature=self.temperature, reasoning_effort=self.reasoning_effort, metadata=metadata, max_retries=99 ) else: response = completion( model=self.model_name, messages=formatted_messages, temperature=self.temperature, metadata=metadata, max_retries=99 ) if self.print_cost: # pass your response from completion to completion_cost cost = completion_cost(completion_response=response) formatted_string = f"Cost: ${float(cost):.10f}" # print(formatted_string) self.accumulated_cost += cost print(f"Accumulated Cost: ${self.accumulated_cost:.10f}") content = response.choices[0].message.content if content is None: print(f"Got null response from model. Full response: {response}") return content except Exception as e: print(f"Error in model completion: {e}") return str(e) if __name__ == "__main__": pass