# TODO: This implementation is not based on textgrad; it is a direct implementation on top of the LiteLLM API.
# Detached from textgrad: https://github.com/zou-group/textgrad/blob/main/textgrad/engine_experimental/litellm.py
try:
import litellm
from litellm import supports_reasoning
except ImportError:
raise ImportError("If you'd like to use LiteLLM, please install the litellm package by running `pip install litellm`, and set appropriate API keys for the models you want to use.")
import os
import json
import base64
import platformdirs
import logging
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from typing import List, Union, Any, Dict
from .base import EngineLM, CachedEngine
from .engine_utils import get_image_type_from_bytes
def validate_structured_output_model(model_string: str) -> bool:
"""
Check if the model supports structured outputs.
Args:
model_string: The name of the model to check
Returns:
True if the model supports structured outputs, False otherwise
"""
# Models that support structured outputs
structure_output_models = [
"gpt-4",
"claude-opus-4", "claude-sonnet-4", "claude-3.7-sonnet", "claude-3.5-sonnet", "claude-3-opus",
"gemini-",
]
return any(x in model_string.lower() for x in structure_output_models)
def validate_chat_model(model_string: str) -> bool:
    """
    Check if the model is a chat model. Nearly all models exposed through
    LiteLLM use the chat-completion interface, so this defaults to True.
    """
    return True
def validate_reasoning_model(model_string: str) -> bool:
"""
Check if the model is a reasoning model.
Includes OpenAI o1/o3/o4 variants (non-pro), Claude models, and other LLMs known for reasoning.
"""
m = model_string.lower()
if supports_reasoning(model_string):
return True
    # Fallback: name-based heuristics for models LiteLLM does not flag.
    if any(x in m for x in ["o1", "o3", "o4"]) and not validate_pro_reasoning_model(model_string):
        return True
    if "claude" in m and not validate_pro_reasoning_model(model_string):
        return True
    extra = ["qwen-72b", "llama-3-70b", "mistral-large", "deepseek-reasoner", "xai/grok-3", "gemini-2.5-pro"]
    if any(e in m for e in extra):
return True
return False
def validate_pro_reasoning_model(model_string: str) -> bool:
"""
Check if the model is a pro reasoning model:
OpenAI o1-pro, o3-pro, o4-pro, and Claude-4/Sonnet variants.
"""
m = model_string.lower()
if any(x in m for x in ["o1-pro", "o3-pro", "o4-pro"]):
return True
if any(x in m for x in ["claude-opus-4", "claude-sonnet-4", "claude-3.7-sonnet"]):
return True
return False
def validate_multimodal_model(model_string: str) -> bool:
"""
Check if the model supports multimodal inputs.
Args:
model_string: The name of the model to check
Returns:
True if the model supports multimodal inputs, False otherwise
"""
m = model_string.lower()
    # Core multimodal models
    multimodal_models = [
        "gpt-4-vision", "gpt-4o", "gpt-4.1",  # OpenAI multimodal
        "gpt-4v",                             # alias for vision-capable GPT-4
        "claude-sonnet", "claude-opus",       # Claude multimodal variants
        "gemini",                             # base Gemini models are multimodal
        "llama-4",                            # reported as multimodal
        "qwen-vl", "qwen2-vl",                # Qwen vision-language models
    ]
    # Gemini TTS / audio-capable variants (audio is a separate modality); redundant
    # while "gemini" is in the core list above, but kept as a safety net.
    audio_models = ["-tts", "-flash-preview-tts", "-pro-preview-tts"]
    if any(g in m for g in multimodal_models):
        return True
    if "gemini" in m and any(s in m for s in audio_models):
        return True  # e.g. gemini-2.5-flash-preview-tts
    # Catch generic vision / vision-language naming patterns (e.g., "gpt-4-vision",
    # "qwen-vl"); note this is a loose substring check.
    if "vision" in m or "vl" in m:
        return True
return False
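# Illustrative sketch (not exercised anywhere in this module): the validators above
# are substring heuristics, so provider-prefixed names match as well. The model
# names below are examples only.
def _demo_capability_checks() -> None:
    assert validate_structured_output_model("openai/gpt-4o")
    assert not validate_structured_output_model("gpt-3.5-turbo")
    assert validate_pro_reasoning_model("o3-pro-2025-06-10")
    assert validate_multimodal_model("gemini-1.5-pro")
    assert validate_multimodal_model("qwen2-vl-72b-instruct")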
class ChatLiteLLM(EngineLM, CachedEngine):
"""
LiteLLM implementation of the EngineLM interface.
This allows using any model supported by LiteLLM.
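    Example (illustrative; the model name is arbitrary and the matching provider
    API key, e.g. OPENAI_API_KEY, must be set in the environment):
        engine = ChatLiteLLM(model_string="gpt-4o-mini")
        reply = engine("Summarize LiteLLM in one sentence.")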
"""
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
def __init__(
self,
model_string: str = "gpt-3.5-turbo",
use_cache: bool = False,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
is_multimodal: bool = False,
**kwargs
):
"""
Initialize the LiteLLM engine.
Args:
model_string: The name of the model to use
use_cache: Whether to use caching
system_prompt: The system prompt to use
is_multimodal: Whether to enable multimodal capabilities
**kwargs: Additional arguments to pass to the LiteLLM client
"""
self.model_string = model_string
self.use_cache = use_cache
self.system_prompt = system_prompt
self.is_multimodal = is_multimodal or validate_multimodal_model(model_string)
self.kwargs = kwargs
# Set up caching if enabled
if self.use_cache:
root = platformdirs.user_cache_dir("agentflow")
cache_path = os.path.join(root, f"cache_litellm_{model_string}.db")
self.image_cache_dir = os.path.join(root, "image_cache")
os.makedirs(self.image_cache_dir, exist_ok=True)
super().__init__(cache_path=cache_path)
# Disable telemetry
litellm.telemetry = False
# Set model capabilities based on model name
self.support_structured_output = validate_structured_output_model(self.model_string)
self.is_chat_model = validate_chat_model(self.model_string)
self.is_reasoning_model = validate_reasoning_model(self.model_string)
self.is_pro_reasoning_model = validate_pro_reasoning_model(self.model_string)
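        # Illustrative: these flags are derived once from the model name, e.g.
        #   ChatLiteLLM("gpt-4o").is_multimodal          -> True
        #   ChatLiteLLM("o3-pro").is_pro_reasoning_model -> True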
# Suppress LiteLLM debug logs
litellm.suppress_debug_info = True
for key in logging.Logger.manager.loggerDict.keys():
if "litellm" in key.lower():
logging.getLogger(key).setLevel(logging.WARNING)
def __call__(self, prompt, **kwargs):
"""
Handle direct calls to the instance (e.g., model(prompt)).
Forwards the call to the generate method.
"""
return self.generate(prompt, **kwargs)
def _format_content(self, content: List[Union[str, bytes]]) -> List[Dict[str, Any]]:
"""
Format content for the LiteLLM API.
Args:
content: List of content items (strings and/or image bytes)
Returns:
Formatted content for the LiteLLM API
"""
formatted_content = []
for item in content:
if isinstance(item, str):
formatted_content.append({"type": "text", "text": item})
elif isinstance(item, bytes):
# For images, encode as base64
image_type = get_image_type_from_bytes(item)
                # Skip image bytes whose format cannot be identified
                if image_type:
base64_image = base64.b64encode(item).decode('utf-8')
formatted_content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/{image_type};base64,{base64_image}",
"detail": "auto"
}
})
elif isinstance(item, dict) and "type" in item:
# Already formatted content
formatted_content.append(item)
return formatted_content
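    # Illustrative example of the structure _format_content produces:
    #   _format_content(["Describe this image:", png_bytes])
    #   -> [{"type": "text", "text": "Describe this image:"},
    #       {"type": "image_url",
    #        "image_url": {"url": "data:image/png;base64,<...>", "detail": "auto"}}]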
@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
"""
Generate text from a prompt.
Args:
content: A string prompt or a list of strings and image bytes
system_prompt: Optional system prompt to override the default
**kwargs: Additional arguments to pass to the LiteLLM API
Returns:
Generated text response
"""
        # Exceptions below are converted into error dicts rather than re-raised, so
        # the tenacity retry decorator above does not retry on them.
        try:
            if isinstance(content, str):
                return self._generate_text(content, system_prompt=system_prompt, **kwargs)
            elif isinstance(content, list):
                has_multimodal_input = any(isinstance(item, bytes) for item in content)
                if has_multimodal_input and not self.is_multimodal:
                    raise NotImplementedError(f"Multimodal generation is only supported for multimodal models. Current model: {self.model_string}")
                return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
            else:
                raise TypeError(f"Unsupported content type: {type(content).__name__}; expected str or list.")
except litellm.exceptions.BadRequestError as e:
print(f"Bad request error: {str(e)}")
return {
"error": "bad_request",
"message": str(e),
"details": getattr(e, 'args', None)
}
except litellm.exceptions.RateLimitError as e:
print(f"Rate limit error encountered: {str(e)}")
return {
"error": "rate_limit",
"message": str(e),
"details": getattr(e, 'args', None)
}
except litellm.exceptions.ContextWindowExceededError as e:
print(f"Context window exceeded: {str(e)}")
return {
"error": "context_window_exceeded",
"message": str(e),
"details": getattr(e, 'args', None)
}
        # Handle connection errors before the more general APIError.
        except litellm.exceptions.APIConnectionError as e:
            print(f"API connection error: {str(e)}")
            return {
                "error": "api_connection_error",
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
        except litellm.exceptions.APIError as e:
            print(f"API error: {str(e)}")
            return {
                "error": "api_error",
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
except Exception as e:
print(f"Error in generate method: {str(e)}")
print(f"Error type: {type(e).__name__}")
print(f"Error details: {e.args}")
return {
"error": type(e).__name__,
"message": str(e),
"details": getattr(e, 'args', None)
}
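    # Illustrative caller-side handling (assumption: successful calls return a str,
    # while failures return the error dicts built above):
    #   result = engine.generate("hello")
    #   if isinstance(result, dict) and "error" in result:
    #       ...  # inspect result["error"] and result["message"]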
def _generate_text(
self, prompt, system_prompt=None, temperature=0, max_tokens=4000, top_p=0.99, response_format=None, **kwargs
):
"""
Generate text from a text prompt.
Args:
prompt: The text prompt
system_prompt: Optional system prompt to override the default
temperature: Controls randomness (higher = more random)
max_tokens: Maximum number of tokens to generate
top_p: Controls diversity via nucleus sampling
response_format: Optional response format for structured outputs
**kwargs: Additional arguments to pass to the LiteLLM API
Returns:
Generated text response
"""
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
if self.use_cache:
cache_key = sys_prompt_arg + prompt
cache_or_none = self._check_cache(cache_key)
if cache_or_none is not None:
return cache_or_none
messages = [
{"role": "system", "content": sys_prompt_arg},
{"role": "user", "content": prompt},
]
# Prepare additional parameters
params = {
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
}
# Add response_format if supported and provided
if self.support_structured_output and response_format:
params["response_format"] = response_format
# Add any additional kwargs
params.update(self.kwargs)
params.update(kwargs)
# Make the API call
response = litellm.completion(
model=self.model_string,
messages=messages,
**params
)
response_text = response.choices[0].message.content
if self.use_cache:
self._save_cache(cache_key, response_text)
return response_text
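    # Illustrative structured-output sketch (forwarded only when
    # self.support_structured_output is True). The schema is a hypothetical example
    # of the OpenAI-style response_format that LiteLLM passes through to providers:
    #   response_format = {
    #       "type": "json_schema",
    #       "json_schema": {
    #           "name": "answer",
    #           "schema": {"type": "object",
    #                      "properties": {"answer": {"type": "string"}}},
    #       },
    #   }
    #   engine._generate_text("Reply in JSON.", response_format=response_format)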
def _generate_multimodal(
self, content_list, system_prompt=None, temperature=0, max_tokens=4000, top_p=0.99, **kwargs
):
"""
Generate text from a multimodal prompt (text and images).
Args:
content_list: List of content items (strings and/or image bytes)
system_prompt: Optional system prompt to override the default
temperature: Controls randomness (higher = more random)
max_tokens: Maximum number of tokens to generate
top_p: Controls diversity via nucleus sampling
**kwargs: Additional arguments to pass to the LiteLLM API
Returns:
Generated text response
"""
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
formatted_content = self._format_content(content_list)
if self.use_cache:
cache_key = sys_prompt_arg + json.dumps(str(formatted_content))
cache_or_none = self._check_cache(cache_key)
if cache_or_none is not None:
return cache_or_none
messages = [
{"role": "system", "content": sys_prompt_arg},
{"role": "user", "content": formatted_content},
]
# Prepare additional parameters
params = {
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
}
# Add any additional kwargs
params.update(self.kwargs)
params.update(kwargs)
# Make the API call
response = litellm.completion(
model=self.model_string,
messages=messages,
**params
)
response_text = response.choices[0].message.content
if self.use_cache:
self._save_cache(cache_key, response_text)
return response_text
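

if __name__ == "__main__":
    # Minimal smoke test (illustrative): requires a valid provider API key in the
    # environment (e.g. OPENAI_API_KEY for OpenAI models); the model name is an
    # example, not a requirement.
    engine = ChatLiteLLM(model_string="gpt-4o-mini")
    print(engine("Reply with the single word: pong"))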