# selfevolveagent / evoagentx / models / litellm_model.py
# iLOVE2D's picture
# Upload 2846 files
# 5374a2d verified
import os
import litellm
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from litellm import completion, acompletion
from typing import List
from ..core.registry import register_model
from .model_configs import LiteLLMConfig
from .openai_model import OpenAILLM
from .model_utils import infer_litellm_company_from_model, Cost
@register_model(config_cls=LiteLLMConfig, alias=["litellm"])
class LiteLLM(OpenAILLM):
    """LLM wrapper that sends requests through ``litellm.completion`` / ``litellm.acompletion``.

    Hosted providers are configured by exporting their credentials from
    ``LiteLLMConfig`` into environment variables. Local deployments and Azure
    deployments additionally receive their endpoint (and, for Azure, key and
    API version) as per-request parameters.
    """

    def _uses_local_endpoint(self, company: str) -> bool:
        """Return True when the model should be served from ``api_base`` (a local deployment)."""
        return bool(self.config.is_local) or company == "local"

    def init_model(self):
        """
        Initialize the model based on the configuration.

        Validates ``llm_type`` and stores ``model``, ``api_base`` and ``api_key``
        on the instance. Then, depending on the company inferred from the model
        name, it either configures litellm's module-level ``api_base`` /
        ``api_key`` (local models) or exports the provider credential(s) to
        environment variables.

        Raises:
            ValueError: If ``llm_type`` is not ``"LiteLLM"``, a required key or
                endpoint for the inferred company is missing, or the company is
                not supported.
        """
        if self.config.llm_type != "LiteLLM":
            raise ValueError("llm_type must be 'LiteLLM'")

        self.model = self.config.model
        self.api_base = self.config.api_base  # kept for per-request use by local models
        self.api_key = self.config.api_key
        company = infer_litellm_company_from_model(self.model)

        if self._uses_local_endpoint(company):
            if not self.api_base:
                raise ValueError("api_base is required for local models in LiteLLMConfig")
            # Local LLMs don't need a provider key; set litellm's global defaults instead.
            litellm.api_base = self.api_base
            litellm.api_key = self.api_key
        elif company == "openai":
            if not self.config.openai_key:
                raise ValueError("OpenAI API key is required for OpenAI models. You should set `openai_key` in LiteLLMConfig")
            os.environ["OPENAI_API_KEY"] = self.config.openai_key
        elif company == "azure":
            if not self.config.azure_key or not self.config.azure_endpoint:
                raise ValueError("Azure OpenAI key and endpoint are required for Azure models. You should set `azure_key` and `azure_endpoint` in LiteLLMConfig")
            os.environ["AZURE_API_KEY"] = self.config.azure_key
            os.environ["AZURE_API_BASE"] = self.config.azure_endpoint
            if self.config.api_version:
                os.environ["AZURE_API_VERSION"] = self.config.api_version
        elif company == "deepseek":
            if not self.config.deepseek_key:
                raise ValueError("DeepSeek API key is required for DeepSeek models. You should set `deepseek_key` in LiteLLMConfig")
            os.environ["DEEPSEEK_API_KEY"] = self.config.deepseek_key
        elif company == "anthropic":
            if not self.config.anthropic_key:
                raise ValueError("Anthropic API key is required for Anthropic models. You should set `anthropic_key` in LiteLLMConfig")
            os.environ["ANTHROPIC_API_KEY"] = self.config.anthropic_key
        elif company == "gemini":
            if not self.config.gemini_key:
                raise ValueError("Gemini API key is required for Gemini models. You should set `gemini_key` in LiteLLMConfig")
            os.environ["GEMINI_API_KEY"] = self.config.gemini_key
        elif company == "meta_llama":
            if not self.config.meta_llama_key:
                raise ValueError("Meta Llama API key is required for Meta Llama models. You should set `meta_llama_key` in LiteLLMConfig")
            os.environ["LLAMA_API_KEY"] = self.config.meta_llama_key
        elif company == "openrouter":
            if not self.config.openrouter_key:
                raise ValueError("OpenRouter API key is required for OpenRouter models. You should set `openrouter_key` in LiteLLMConfig. You can also set `openrouter_base` in LiteLLMConfig to use a custom base URL [optional]")
            os.environ["OPENROUTER_API_KEY"] = self.config.openrouter_key
            # openrouter_base is optional. os.environ only accepts str values,
            # so assigning an unset (None) base would raise TypeError.
            if self.config.openrouter_base:
                os.environ["OPENROUTER_API_BASE"] = self.config.openrouter_base
        elif company == "perplexity":
            if not self.config.perplexity_key:
                raise ValueError("Perplexity API key is required for Perplexity models. You should set `perplexity_key` in LiteLLMConfig")
            os.environ["PERPLEXITYAI_API_KEY"] = self.config.perplexity_key
        elif company == "groq":
            if not self.config.groq_key:
                raise ValueError("Groq API key is required for Groq models. You should set `groq_key` in LiteLLMConfig")
            os.environ["GROQ_API_KEY"] = self.config.groq_key
        else:
            raise ValueError(f"Unsupported company: {company}")

        # Parameters in LiteLLMConfig that are not LiteLLM models' input parameters.
        self._default_ignore_fields = [
            "llm_type", "output_response", "openai_key", "deepseek_key", "anthropic_key",
            "gemini_key", "meta_llama_key", "openrouter_key", "openrouter_base", "perplexity_key",
            "groq_key", "api_base", "is_local", "azure_endpoint", "azure_key", "api_version", "api_key"
        ]

    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
        """Return the cost of a call. Local models are free; others defer to the parent implementation.

        Uses the same "is this local?" rule as ``init_model`` (``config.is_local`` or an
        inferred company of ``"local"``), so locally-inferred models are not priced.
        """
        company = infer_litellm_company_from_model(self.model)
        if self._uses_local_endpoint(company):
            return Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=0.0, output_cost=0.0)
        return super()._compute_cost(input_tokens, output_tokens)

    def _build_litellm_params(self, **kwargs) -> dict:
        """Build the keyword arguments for ``completion``/``acompletion``.

        Starts from ``get_completion_params(**kwargs)`` and adds per-request
        endpoint settings: ``api_base`` for local models, and ``api_base``,
        ``api_version`` and ``api_key`` for Azure models.
        """
        completion_params = self.get_completion_params(**kwargs)
        company = infer_litellm_company_from_model(self.model)
        if self._uses_local_endpoint(company):
            completion_params["api_base"] = self.api_base
        elif company == "azure":
            completion_params["api_base"] = self.config.azure_endpoint
            completion_params["api_version"] = self.config.api_version
            completion_params["api_key"] = self.config.azure_key
        return completion_params

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the completion function.

        The call is retried (random exponential wait, up to 5 attempts) by the
        ``tenacity`` decorator.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.
                ``stream`` and ``output_response`` override the config values of the same name.

        Returns:
            str: A string containing the model's response.

        Raises:
            RuntimeError: Wrapping any error raised while building params, calling
                the model, or extracting the output/cost.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            completion_params = self._build_litellm_params(**kwargs)
            response = completion(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate: {str(e)}") from e
        return output

    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        """
        Generate responses for a batch of messages, sequentially.

        Args:
            batch_messages (List[List[dict]]): A list of message lists, where each sublist represents a conversation.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.

        Returns:
            List[str]: A list of responses for each conversation, in input order.
        """
        return [self.single_generate(messages, **kwargs) for messages in batch_messages]

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the async completion function.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `acompletion` function.
                ``stream`` and ``output_response`` override the config values of the same name.

        Returns:
            str: A string containing the model's response.

        Raises:
            RuntimeError: Wrapping any error raised while building params, calling
                the model, or extracting the output/cost.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            completion_params = self._build_litellm_params(**kwargs)
            response = await acompletion(messages=messages, **completion_params)
            if stream:
                # Prefer async iteration when the stream supports it; otherwise consume it synchronously.
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate_async: {str(e)}") from e
        return output