File size: 10,157 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import os
import litellm
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from litellm import completion, acompletion
from typing import List
from ..core.registry import register_model
from .model_configs import LiteLLMConfig
from .openai_model import OpenAILLM
from .model_utils import infer_litellm_company_from_model, Cost
@register_model(config_cls=LiteLLMConfig, alias=["litellm"])
class LiteLLM(OpenAILLM):
    """LiteLLM-backed model wrapper.

    Routes requests through the ``litellm`` SDK to hosted providers (OpenAI,
    Azure, DeepSeek, Anthropic, Gemini, Meta Llama, OpenRouter, Perplexity,
    Groq) or to a locally hosted endpoint addressed via ``api_base``.
    """

    def init_model(self):
        """
        Initialize the model based on the configuration.

        Resolves the provider ("company") from the model name, then exports the
        matching credentials — as ``litellm`` globals for local endpoints, or as
        environment variables for hosted providers.

        Raises:
            ValueError: If ``llm_type`` is not ``"LiteLLM"``, if a credential
                required by the resolved provider is missing, or if the provider
                is unsupported.
        """
        # Check if llm_type is correct
        if self.config.llm_type != "LiteLLM":
            raise ValueError("llm_type must be 'LiteLLM'")
        # Set model and extract the company name
        self.model = self.config.model
        self.api_base = self.config.api_base  # save api_base
        self.api_key = self.config.api_key
        company = infer_litellm_company_from_model(self.model)
        if self.config.is_local or company == "local":  # update support local model
            if not self.api_base:
                raise ValueError("api_base is required for local models in LiteLLMConfig")
            # Local servers are addressed via litellm globals; an API key is
            # optional for local deployments.
            litellm.api_base = self.api_base  # set litellm global api_base
            litellm.api_key = self.api_key
        else:
            # Set environment variables based on the company
            if company == "openai":
                if not self.config.openai_key:
                    raise ValueError("OpenAI API key is required for OpenAI models. You should set `openai_key` in LiteLLMConfig")
                os.environ["OPENAI_API_KEY"] = self.config.openai_key
            elif company == "azure":
                if not self.config.azure_key or not self.config.azure_endpoint:
                    raise ValueError("Azure OpenAI key and endpoint are required for Azure models. You should set `azure_key` and `azure_endpoint` in LiteLLMConfig")
                os.environ["AZURE_API_KEY"] = self.config.azure_key
                os.environ["AZURE_API_BASE"] = self.config.azure_endpoint
                if self.config.api_version:
                    os.environ["AZURE_API_VERSION"] = self.config.api_version
            elif company == "deepseek":
                if not self.config.deepseek_key:
                    raise ValueError("DeepSeek API key is required for DeepSeek models. You should set `deepseek_key` in LiteLLMConfig")
                os.environ["DEEPSEEK_API_KEY"] = self.config.deepseek_key
            elif company == "anthropic":
                if not self.config.anthropic_key:
                    raise ValueError("Anthropic API key is required for Anthropic models. You should set `anthropic_key` in LiteLLMConfig")
                os.environ["ANTHROPIC_API_KEY"] = self.config.anthropic_key
            elif company == "gemini":
                if not self.config.gemini_key:
                    raise ValueError("Gemini API key is required for Gemini models. You should set `gemini_key` in LiteLLMConfig")
                os.environ["GEMINI_API_KEY"] = self.config.gemini_key
            elif company == "meta_llama":
                if not self.config.meta_llama_key:
                    raise ValueError("Meta Llama API key is required for Meta Llama models. You should set `meta_llama_key` in LiteLLMConfig")
                os.environ["LLAMA_API_KEY"] = self.config.meta_llama_key
            elif company == "openrouter":
                if not self.config.openrouter_key:
                    raise ValueError("OpenRouter API key is required for OpenRouter models. You should set `openrouter_key` in LiteLLMConfig. You can also set `openrouter_base` in LiteLLMConfig to use a custom base URL [optional]")
                os.environ["OPENROUTER_API_KEY"] = self.config.openrouter_key
                # BUG FIX: os.environ values must be strings — exporting the
                # optional base URL unconditionally raised TypeError when
                # `openrouter_base` was left unset.
                if self.config.openrouter_base:
                    os.environ["OPENROUTER_API_BASE"] = self.config.openrouter_base  # [optional]
            elif company == "perplexity":
                if not self.config.perplexity_key:
                    raise ValueError("Perplexity API key is required for Perplexity models. You should set `perplexity_key` in LiteLLMConfig")
                os.environ["PERPLEXITYAI_API_KEY"] = self.config.perplexity_key
            elif company == "groq":
                if not self.config.groq_key:
                    raise ValueError("Groq API key is required for Groq models. You should set `groq_key` in LiteLLMConfig")
                os.environ["GROQ_API_KEY"] = self.config.groq_key
            else:
                raise ValueError(f"Unsupported company: {company}")
        self._default_ignore_fields = [
            "llm_type", "output_response", "openai_key", "deepseek_key", "anthropic_key",
            "gemini_key", "meta_llama_key", "openrouter_key", "openrouter_base", "perplexity_key",
            "groq_key", "api_base", "is_local", "azure_endpoint", "azure_key", "api_version", "api_key"
        ]  # parameters in LiteLLMConfig that are not LiteLLM models' input parameters

    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
        """Return a zero-cost record for local models; otherwise defer to the parent."""
        if self.config.is_local:
            return Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=0.0, output_cost=0.0)
        return super()._compute_cost(input_tokens, output_tokens)

    def _provider_completion_params(self, completion_params: dict) -> dict:
        """Inject provider-specific routing parameters into ``completion_params``.

        Shared by the sync and async generate paths (previously duplicated).
        Mutates and returns the same dict.
        """
        company = infer_litellm_company_from_model(self.model)
        if self.config.is_local or company == "local":  # add api base for local model
            completion_params["api_base"] = self.api_base
        elif company == "azure":  # Add Azure OpenAI specific parameters
            completion_params["api_base"] = self.config.azure_endpoint
            completion_params["api_version"] = self.config.api_version
            completion_params["api_key"] = self.config.azure_key
        return completion_params

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the completion function.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.

        Returns:
            str: A string containing the model's response.

        Raises:
            RuntimeError: If the completion call fails (original exception chained).
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            completion_params = self._provider_completion_params(self.get_completion_params(**kwargs))
            response = completion(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)
        except Exception as e:
            # Chain the cause so retry/debugging keeps the original traceback.
            raise RuntimeError(f"Error during single_generate: {str(e)}") from e
        return output

    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        """
        Generate responses for a batch of messages.

        Args:
            batch_messages (List[List[dict]]): A list of message lists, where each sublist represents a conversation.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.

        Returns:
            List[str]: A list of responses for each conversation.
        """
        # Sequential on purpose: each call already retries independently.
        return [self.single_generate(messages, **kwargs) for messages in batch_messages]

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the async completion function.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.

        Returns:
            str: A string containing the model's response.

        Raises:
            RuntimeError: If the async completion call fails (original exception chained).
        """
        # NOTE(review): unlike single_generate, this path has no @retry —
        # confirm whether that asymmetry is intentional.
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            completion_params = self._provider_completion_params(self.get_completion_params(**kwargs))
            response = await acompletion(messages=messages, **completion_params)
            if stream:
                # acompletion may return an async iterator (true streaming) or a
                # sync-iterable wrapper; handle both.
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)
        except Exception as e:
            # Chain the cause so callers see the original failure, not just the wrapper.
            raise RuntimeError(f"Error during single_generate_async: {str(e)}") from e
        return output
|