# TODO: The current implementation is not based on textgrad, but rather a direct implementation of the LiteLLM API.
# Detached from textgrad: https://github.com/zou-group/textgrad/blob/main/textgrad/engine_experimental/litellm.py
try:
    import litellm
    from litellm import supports_reasoning
except ImportError:
    raise ImportError(
        "If you'd like to use LiteLLM, please install the litellm package by running "
        "`pip install litellm`, and set appropriate API keys for the models you want to use."
    )

import os
import json
import base64
import platformdirs
import logging
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from typing import List, Union, Optional, Any, Dict

from .base import EngineLM, CachedEngine
from .engine_utils import get_image_type_from_bytes


def validate_structured_output_model(model_string: str) -> bool:
    """
    Check if the model supports structured outputs.

    Args:
        model_string: The name of the model to check

    Returns:
        True if the model supports structured outputs, False otherwise
    """
    # Models that support structured outputs
    structure_output_models = [
        "gpt-4",
        "claude-opus-4",
        "claude-sonnet-4",
        "claude-3.7-sonnet",
        "claude-3.5-sonnet",
        "claude-3-opus",
        "gemini-",
    ]
    return any(x in model_string.lower() for x in structure_output_models)


def validate_chat_model(model_string: str) -> bool:
    # The vast majority of LiteLLM models are chat models
    return True


def validate_reasoning_model(model_string: str) -> bool:
    """
    Check if the model is a reasoning model.

    Includes OpenAI o1/o3/o4 variants (non-pro), Claude models, and other
    LLMs known for reasoning.
    """
    m = model_string.lower()
    if supports_reasoning(model_string):
        return True
    # Hard-coded fallbacks for models LiteLLM does not flag as reasoning-capable
    if any(x in m for x in ["o1", "o3", "o4"]) and not validate_pro_reasoning_model(model_string):
        return True
    if "claude" in m and not validate_pro_reasoning_model(model_string):
        return True
    extra = ["qwen-72b", "llama-3-70b", "mistral-large", "deepseek-reasoner", "xai/grok-3", "gemini-2.5-pro"]
    if any(e in m for e in extra):
        return True
    return False


def validate_pro_reasoning_model(model_string: str) -> bool:
    """
    Check if the model is a pro reasoning model: OpenAI o1-pro, o3-pro, o4-pro,
    and Claude-4/Sonnet variants.
    """
    m = model_string.lower()
    if any(x in m for x in ["o1-pro", "o3-pro", "o4-pro"]):
        return True
    if any(x in m for x in ["claude-opus-4", "claude-sonnet-4", "claude-3.7-sonnet"]):
        return True
    return False


def validate_multimodal_model(model_string: str) -> bool:
    """
    Check if the model supports multimodal inputs.

    Args:
        model_string: The name of the model to check

    Returns:
        True if the model supports multimodal inputs, False otherwise
    """
    m = model_string.lower()
    # Core multimodal models
    multimodal_models = [
        "gpt-4-vision", "gpt-4o", "gpt-4.1",  # OpenAI multimodal
        "gpt-4v",                             # alias for vision-capable GPT-4
        "claude-sonnet", "claude-opus",       # Claude multimodal variants
        "gemini",                             # base Gemini models are multimodal
        "llama-4",                            # reported as multimodal
        "qwen-vl", "qwen2-vl",                # Qwen vision-language models
    ]
    # Gemini TTS / audio-capable variants (audio rather than vision, but still multimodal)
    audio_models = ["-tts", "-flash-preview-tts", "-pro-preview-tts"]
    if any(g in m for g in multimodal_models):
        return True
    if "gemini" in m and any(s in m for s in audio_models):
        return True  # e.g. gemini-2.5-flash-preview-tts
    # Make sure we catch edge cases like "gpt-4v" or "gpt-4 vision"
    if "vision" in m or "vl" in m:
        return True
    return False
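
# Illustrative behaviour of the validators above (a sketch; the sample model
# names are assumptions, and the supports_reasoning() result depends on the
# installed litellm version):
#
#   validate_multimodal_model("gpt-4o")            # True  -> "gpt-4o" is listed
#   validate_multimodal_model("gpt-3.5-turbo")     # False -> no multimodal marker
#   validate_reasoning_model("o3-mini")            # True  -> "o3" substring, not a pro variant
#   validate_pro_reasoning_model("o1-pro")         # True  -> explicit pro variant
#   validate_structured_output_model("gemini-2.0-flash")  # True -> "gemini-" prefix match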


class ChatLiteLLM(EngineLM, CachedEngine):
    """
    LiteLLM implementation of the EngineLM interface.

    This allows using any model supported by LiteLLM.
    """

    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(
        self,
        model_string: str = "gpt-3.5-turbo",
        use_cache: bool = False,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        is_multimodal: bool = False,
        **kwargs
    ):
        """
        Initialize the LiteLLM engine.

        Args:
            model_string: The name of the model to use
            use_cache: Whether to use caching
            system_prompt: The system prompt to use
            is_multimodal: Whether to enable multimodal capabilities
            **kwargs: Additional arguments to pass to the LiteLLM client
        """
        self.model_string = model_string
        self.use_cache = use_cache
        self.system_prompt = system_prompt
        self.is_multimodal = is_multimodal or validate_multimodal_model(model_string)
        self.kwargs = kwargs

        # Set up caching if enabled
        if self.use_cache:
            root = platformdirs.user_cache_dir("agentflow")
            cache_path = os.path.join(root, f"cache_litellm_{model_string}.db")
            self.image_cache_dir = os.path.join(root, "image_cache")
            os.makedirs(self.image_cache_dir, exist_ok=True)
            super().__init__(cache_path=cache_path)

        # Disable telemetry
        litellm.telemetry = False

        # Set model capabilities based on model name
        self.support_structured_output = validate_structured_output_model(self.model_string)
        self.is_chat_model = validate_chat_model(self.model_string)
        self.is_reasoning_model = validate_reasoning_model(self.model_string)
        self.is_pro_reasoning_model = validate_pro_reasoning_model(self.model_string)

        # Suppress LiteLLM debug logs
        litellm.suppress_debug_info = True
        for key in logging.Logger.manager.loggerDict.keys():
            if "litellm" in key.lower():
                logging.getLogger(key).setLevel(logging.WARNING)

    def __call__(self, prompt, **kwargs):
        """
        Handle direct calls to the instance (e.g., model(prompt)).

        Forwards the call to the generate method.
        """
        return self.generate(prompt, **kwargs)

    def _format_content(self, content: List[Union[str, bytes]]) -> List[Dict[str, Any]]:
        """
        Format content for the LiteLLM API.

        Args:
            content: List of content items (strings and/or image bytes)

        Returns:
            Formatted content for the LiteLLM API
        """
        formatted_content = []
        for item in content:
            if isinstance(item, str):
                formatted_content.append({"type": "text", "text": item})
            elif isinstance(item, bytes):
                # For images, encode as base64
                image_type = get_image_type_from_bytes(item)
                if image_type:
                    base64_image = base64.b64encode(item).decode('utf-8')
                    formatted_content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/{image_type};base64,{base64_image}",
                            "detail": "auto"
                        }
                    })
            elif isinstance(item, dict) and "type" in item:
                # Already formatted content
                formatted_content.append(item)
        return formatted_content
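
    # Illustration of what _format_content() produces (values are illustrative
    # only; the image bytes and the resulting base64 payload are placeholders):
    #
    #   _format_content(["Describe this image.", b"<png bytes>"])
    #   # -> [{"type": "text", "text": "Describe this image."},
    #   #     {"type": "image_url",
    #   #      "image_url": {"url": "data:image/png;base64,...", "detail": "auto"}}]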

    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
        """
        Generate text from a prompt.

        Args:
            content: A string prompt or a list of strings and image bytes
            system_prompt: Optional system prompt to override the default
            **kwargs: Additional arguments to pass to the LiteLLM API

        Returns:
            Generated text response
        """
        try:
            if isinstance(content, str):
                return self._generate_text(content, system_prompt=system_prompt, **kwargs)
            elif isinstance(content, list):
                has_multimodal_input = any(isinstance(item, bytes) for item in content)
                if has_multimodal_input and not self.is_multimodal:
                    raise NotImplementedError(
                        f"Multimodal generation is only supported for multimodal models. "
                        f"Current model: {self.model_string}"
                    )
                return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
        # ContextWindowExceededError is a subclass of BadRequestError in LiteLLM,
        # so it must be handled before its parent to be reported distinctly.
        except litellm.exceptions.ContextWindowExceededError as e:
            print(f"Context window exceeded: {str(e)}")
            return {
                "error": "context_window_exceeded",
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
        except litellm.exceptions.BadRequestError as e:
            print(f"Bad request error: {str(e)}")
            return {
                "error": "bad_request",
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
        except litellm.exceptions.RateLimitError as e:
            print(f"Rate limit error encountered: {str(e)}")
            return {
                "error": "rate_limit",
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
        except litellm.exceptions.APIError as e:
            print(f"API error: {str(e)}")
            return {
                "error": "api_error",
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
        except litellm.exceptions.APIConnectionError as e:
            print(f"API connection error: {str(e)}")
            return {
                "error": "api_connection_error",
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
        except Exception as e:
            print(f"Error in generate method: {str(e)}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error details: {e.args}")
            return {
                "error": type(e).__name__,
                "message": str(e),
                "details": getattr(e, 'args', None)
            }

    def _generate_text(
        self,
        prompt,
        system_prompt=None,
        temperature=0,
        max_tokens=4000,
        top_p=0.99,
        response_format=None,
        **kwargs
    ):
        """
        Generate text from a text prompt.

        Args:
            prompt: The text prompt
            system_prompt: Optional system prompt to override the default
            temperature: Controls randomness (higher = more random)
            max_tokens: Maximum number of tokens to generate
            top_p: Controls diversity via nucleus sampling
            response_format: Optional response format for structured outputs
            **kwargs: Additional arguments to pass to the LiteLLM API

        Returns:
            Generated text response
        """
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        if self.use_cache:
            cache_key = sys_prompt_arg + prompt
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        messages = [
            {"role": "system", "content": sys_prompt_arg},
            {"role": "user", "content": prompt},
        ]

        # Prepare additional parameters
        params = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }

        # Add response_format if supported and provided
        if self.support_structured_output and response_format:
            params["response_format"] = response_format

        # Add any additional kwargs
        params.update(self.kwargs)
        params.update(kwargs)

        # Make the API call
        response = litellm.completion(
            model=self.model_string,
            messages=messages,
            **params
        )

        response_text = response.choices[0].message.content

        if self.use_cache:
            self._save_cache(cache_key, response_text)

        return response_text
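
    # Sketch of how a structured output request flows through generate()/_generate_text()
    # (hypothetical usage; `Answer` is an illustrative pydantic model, not part of this module):
    #
    #   from pydantic import BaseModel
    #
    #   class Answer(BaseModel):
    #       answer: str
    #       confidence: float
    #
    #   engine = ChatLiteLLM(model_string="gpt-4o")
    #   engine.generate("Is water wet?", response_format=Answer)
    #
    # response_format is only forwarded to litellm.completion when
    # validate_structured_output_model() matched the model name at init time.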

    def _generate_multimodal(
        self,
        content_list,
        system_prompt=None,
        temperature=0,
        max_tokens=4000,
        top_p=0.99,
        **kwargs
    ):
        """
        Generate text from a multimodal prompt (text and images).

        Args:
            content_list: List of content items (strings and/or image bytes)
            system_prompt: Optional system prompt to override the default
            temperature: Controls randomness (higher = more random)
            max_tokens: Maximum number of tokens to generate
            top_p: Controls diversity via nucleus sampling
            **kwargs: Additional arguments to pass to the LiteLLM API

        Returns:
            Generated text response
        """
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        formatted_content = self._format_content(content_list)

        if self.use_cache:
            cache_key = sys_prompt_arg + json.dumps(str(formatted_content))
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        messages = [
            {"role": "system", "content": sys_prompt_arg},
            {"role": "user", "content": formatted_content},
        ]

        # Prepare additional parameters
        params = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }

        # Add any additional kwargs
        params.update(self.kwargs)
        params.update(kwargs)

        # Make the API call
        response = litellm.completion(
            model=self.model_string,
            messages=messages,
            **params
        )

        response_text = response.choices[0].message.content

        if self.use_cache:
            self._save_cache(cache_key, response_text)

        return response_text
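

# A minimal usage sketch (not part of the module's API surface; the model name
# and image path are placeholders, and the relevant provider API key is assumed
# to be set in the environment, e.g. OPENAI_API_KEY for OpenAI models):
#
#   engine = ChatLiteLLM(model_string="gpt-4o", use_cache=False)
#
#   # Plain text generation
#   print(engine("Summarize the LiteLLM project in one sentence."))
#
#   # Multimodal generation (only works if the model validated as multimodal)
#   with open("example.png", "rb") as f:
#       image_bytes = f.read()
#   print(engine(["What is shown in this image?", image_bytes]))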