Spaces:
Sleeping
Sleeping
File size: 8,867 Bytes
7d4db27 dd65f93 7d4db27 dd65f93 7d4db27 dd65f93 7d4db27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import json
import re
from typing import List, Dict, Any, Union, Optional
import io
import os
import base64
from PIL import Image
import mimetypes
import litellm
from litellm import completion, completion_cost
from dotenv import load_dotenv
import random
load_dotenv()
class LiteLLMWrapper:
    """Wrapper for LiteLLM to support multiple models (GPT, Gemini, ...),
    multimodal messages, cost accounting, and optional Langfuse logging."""

    # OpenAI "o-series" reasoning models reject `temperature` and accept
    # `reasoning_effort` instead. Pattern preserves the original two checks:
    # r"^o\d+.*$"  OR  r"^openai/o.*$" — compiled once instead of per call.
    _O_SERIES_PATTERN = re.compile(r"^(?:o\d+.*|openai/o.*)$")

    def __init__(
        self,
        model_name: str = "gpt-4-vision-preview",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = True,
    ):
        """
        Initialize the LiteLLM wrapper.

        Args:
            model_name: Name of the model to use (e.g. "azure/gpt-4", "vertex_ai/gemini-pro")
            temperature: Temperature for completion (ignored for o-series models)
            print_cost: Whether to print the accumulated cost after each completion
            verbose: Whether to enable LiteLLM debug logging
            use_langfuse: Whether to enable Langfuse success/failure logging
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0
        # Fix: previously created lazily inside __call__ only for o-series
        # models; define it unconditionally so the attribute always exists.
        self.reasoning_effort = "medium"
        # Handle Gemini API key fallback mechanism
        if "gemini" in model_name.lower():
            self._setup_gemini_api_key()
        if self.verbose:
            os.environ['LITELLM_LOG'] = 'DEBUG'
        # Set langfuse callback only if enabled
        if use_langfuse:
            litellm.success_callback = ["langfuse"]
            litellm.failure_callback = ["langfuse"]

    def _setup_gemini_api_key(self):
        """Select and export a Gemini API key.

        Supports a comma-separated pool of keys in GEMINI_API_KEY /
        GOOGLE_API_KEY, from which one key is chosen at random.

        Raises:
            ValueError: If no key is configured, or the comma-separated list
                contains no non-empty entries.
        """
        # Re-read .env with override so key rotation takes effect without a
        # process restart. (Fix: removed duplicate local `from dotenv import
        # load_dotenv` — it is already imported at module level.)
        load_dotenv(override=True)
        gemini_key_env = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not gemini_key_env:
            raise ValueError("No API_KEY found. Please set the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable.")
        # Support comma-separated list of API keys with random selection
        if ',' in gemini_key_env:
            keys = [key.strip() for key in gemini_key_env.split(',') if key.strip()]
            if not keys:
                raise ValueError("No valid API keys found in GEMINI_API_KEY list.")
            api_key = random.choice(keys)
            # Security fix: print only a short prefix of the secret (was 20
            # characters, enough to leak roughly half a Gemini key).
            print(f"Selected random Gemini API key from {len(keys)} available keys: {api_key[:6]}...")
        else:
            api_key = gemini_key_env
            print(f"Using single Gemini API key: {api_key[:6]}...")
        # Export under both names so every LiteLLM provider path finds it.
        os.environ["GEMINI_API_KEY"] = api_key
        os.environ["GOOGLE_API_KEY"] = api_key

    def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
        """
        Encode a local file or PIL Image to a base64 string.

        Args:
            file_path: Path to a local file, or a PIL Image object
                (serialized as PNG).

        Returns:
            Base64-encoded file contents.
        """
        if isinstance(file_path, Image.Image):
            buffered = io.BytesIO()
            file_path.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        else:
            with open(file_path, "rb") as file:
                return base64.b64encode(file.read()).decode("utf-8")

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension.

        Args:
            file_path: Path to the file.

        Returns:
            MIME type string (e.g. "image/jpeg", "audio/mp3").

        Raises:
            ValueError: If the extension is not recognized by `mimetypes`.
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def _format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert this wrapper's message dicts into LiteLLM chat format.

        Messages whose 'type' is not one of "text"/"image"/"audio"/"video"
        are silently skipped (original behavior preserved). Files that fail
        MIME detection are logged and skipped.

        Raises:
            ValueError: For media messages on models other than Gemini/GPT,
                or non-image media on GPT.
        """
        formatted_messages = []
        for msg in messages:
            if msg["type"] == "text":
                formatted_messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": msg["content"]}]
                })
            elif msg["type"] in ["image", "audio", "video"]:
                # Local file or PIL Image -> base64 data URL; anything else is
                # assumed to already be a usable URL.
                if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
                    try:
                        if isinstance(msg["content"], Image.Image):
                            mime_type = "image/png"
                        else:
                            mime_type = self._get_mime_type(msg["content"])
                        base64_data = self._encode_file(msg["content"])
                        data_url = f"data:{mime_type};base64,{base64_data}"
                    except ValueError as e:
                        print(f"Error processing file {msg['content']}: {e}")
                        continue
                else:
                    data_url = msg["content"]
                # Append the formatted message based on the model
                if "gemini" in self.model_name:
                    formatted_messages.append({
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": data_url
                            }
                        ]
                    })
                elif "gpt" in self.model_name:
                    if msg["type"] == "image":
                        # Fix: key was built as f"{msg['type']}_url", which in
                        # this branch can only ever be "image_url"; also
                        # dropped the pointless f-string on "type".
                        formatted_messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": data_url,
                                        "detail": "high"
                                    }
                                }
                            ]
                        })
                    else:
                        raise ValueError("For GPT, only text and image inferencing are supported")
                else:
                    raise ValueError("Only support Gemini and Gpt for Multimodal capability now")
        return formatted_messages

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process messages and return a completion.

        Args:
            messages: List of message dicts with 'type' and 'content' keys.
            metadata: Optional metadata passed to litellm completion, e.g.
                for Langfuse tracking. A trace_name entry is always added.

        Returns:
            Generated text response (may be None if the model returned no
            content), or the stringified exception on failure — callers that
            treat the return value as model output should be aware of this.
        """
        if metadata is None:
            print("No metadata provided, using empty metadata")
            metadata = {}
        metadata["trace_name"] = f"litellm-completion-{self.model_name}"
        # Convert messages to LiteLLM format
        formatted_messages = self._format_messages(messages)
        try:
            if self._O_SERIES_PATTERN.match(self.model_name):
                # o-series models reject `temperature`. Fix: pass None locally
                # instead of permanently clobbering self.temperature (the
                # original mutated instance state as a side effect of a call).
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    temperature=None,
                    reasoning_effort=self.reasoning_effort,
                    metadata=metadata,
                    max_retries=99
                )
            else:
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    temperature=self.temperature,
                    metadata=metadata,
                    max_retries=99
                )
            if self.print_cost:
                # pass your response from completion to completion_cost
                cost = completion_cost(completion_response=response)
                self.accumulated_cost += cost
                print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
            content = response.choices[0].message.content
            if content is None:
                print(f"Got null response from model. Full response: {response}")
            return content
        except Exception as e:
            # Best-effort: surface the error text as the "response" rather
            # than raising (original behavior preserved).
            print(f"Error in model completion: {e}")
            return str(e)
# Script entry-point guard: this module defines no CLI behavior; importing it
# only runs the top-level load_dotenv() and class definition.
if __name__ == "__main__":
    pass