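"""LiteLLM-backed model wrapper.

Routes chat completions through the `litellm` library: infers the provider from
the model identifier, wires up the matching credentials, and exposes sync,
batch, and async generation with streaming support and cost tracking.
"""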
import os
import litellm
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from litellm import completion, acompletion
from typing import List
from ..core.registry import register_model
from .model_configs import LiteLLMConfig
from .openai_model import OpenAILLM
from .model_utils import infer_litellm_company_from_model, Cost

@register_model(config_cls=LiteLLMConfig, alias=["litellm"])
class LiteLLM(OpenAILLM):
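    """Chat model backed by LiteLLM's unified, multi-provider completion API.

    A minimal usage sketch (assuming the base class constructor accepts
    ``config=`` and invokes ``init_model``; the ``LiteLLMConfig`` fields shown
    are those referenced in ``init_model`` below, the model id is illustrative):

        config = LiteLLMConfig(
            llm_type="LiteLLM",
            model="gpt-4o-mini",   # any litellm-routable model id
            openai_key="sk-...",   # provider key matching the model
        )
        llm = LiteLLM(config=config)
        reply = llm.single_generate([{"role": "user", "content": "Hello!"}])
    """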

    def init_model(self):
        """
        Initialize the model based on the configuration.
        """
        # Check if llm_type is correct
        if self.config.llm_type != "LiteLLM":
            raise ValueError("llm_type must be 'LiteLLM'")

        # Cache the model identifier and connection settings
        self.model = self.config.model
        self.api_base = self.config.api_base
        self.api_key = self.config.api_key
        # Infer the provider name (e.g. "openai", "azure") from the model identifier
        company = infer_litellm_company_from_model(self.model)

        if self.config.is_local or company == "local":  # locally served models
            if not self.api_base:
                raise ValueError("api_base is required for local models in LiteLLMConfig")
            # Local endpoints are addressed via litellm's global api_base; the
            # API key is optional and forwarded as-is (many local servers ignore it)
            litellm.api_base = self.api_base
            litellm.api_key = self.api_key
        else:
            # Set environment variables based on the company
            if company == "openai":
                if not self.config.openai_key:
                    raise ValueError("OpenAI API key is required for OpenAI models. You should set `openai_key` in LiteLLMConfig")
                os.environ["OPENAI_API_KEY"] = self.config.openai_key
            elif company == "azure":
                if not self.config.azure_key or not self.config.azure_endpoint:
                    raise ValueError("Azure OpenAI key and endpoint are required for Azure models. You should set `azure_key` and `azure_endpoint` in LiteLLMConfig")
                os.environ["AZURE_API_KEY"] = self.config.azure_key
                os.environ["AZURE_API_BASE"] = self.config.azure_endpoint
                if self.config.api_version:
                    os.environ["AZURE_API_VERSION"] = self.config.api_version
            elif company == "deepseek":
                if not self.config.deepseek_key:
                    raise ValueError("DeepSeek API key is required for DeepSeek models. You should set `deepseek_key` in LiteLLMConfig")
                os.environ["DEEPSEEK_API_KEY"] = self.config.deepseek_key
            elif company == "anthropic":
                if not self.config.anthropic_key:
                    raise ValueError("Anthropic API key is required for Anthropic models. You should set `anthropic_key` in LiteLLMConfig")
                os.environ["ANTHROPIC_API_KEY"] = self.config.anthropic_key
            elif company == "gemini":
                if not self.config.gemini_key:
                    raise ValueError("Gemini API key is required for Gemini models. You should set `gemini_key` in LiteLLMConfig")
                os.environ["GEMINI_API_KEY"] = self.config.gemini_key 
            elif company == "meta_llama":
                if not self.config.meta_llama_key:
                    raise ValueError("Meta Llama API key is required for Meta Llama models. You should set `meta_llama_key` in LiteLLMConfig")
                os.environ["LLAMA_API_KEY"] = self.config.meta_llama_key
            elif company == "openrouter":
                if not self.config.openrouter_key:
                    raise ValueError("OpenRouter API key is required for OpenRouter models. You should set `openrouter_key` in LiteLLMConfig. You can also set `openrouter_base` in LiteLLMConfig to use a custom base URL [optional]")
                os.environ["OPENROUTER_API_KEY"] = self.config.openrouter_key
                os.environ["OPENROUTER_API_BASE"] = self.config.openrouter_base # [optional]
            elif company == "perplexity":
                if not self.config.perplexity_key:
                    raise ValueError("Perplexity API key is required for Perplexity models. You should set `perplexity_key` in LiteLLMConfig")
                os.environ["PERPLEXITYAI_API_KEY"] = self.config.perplexity_key
            elif company == "groq":
                if not self.config.groq_key:
                    raise ValueError("Groq API key is required for Groq models. You should set `groq_key` in LiteLLMConfig")
                os.environ["GROQ_API_KEY"] = self.config.groq_key
            else:
                raise ValueError(f"Unsupported company: {company}")

        # LiteLLMConfig fields that are consumed here and must not be forwarded
        # to litellm's completion call as model parameters
        self._default_ignore_fields = [
            "llm_type", "output_response", "openai_key", "deepseek_key", "anthropic_key",
            "gemini_key", "meta_llama_key", "openrouter_key", "openrouter_base", "perplexity_key",
            "groq_key", "api_base", "is_local", "azure_endpoint", "azure_key", "api_version", "api_key"
        ]
    
    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
        if self.config.is_local:
            return Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=0.0, output_cost=0.0)
        return super()._compute_cost(input_tokens, output_tokens)

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the synchronous `completion` function.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.

        Returns:
            str: A string containing the model's response.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)

        try:
            completion_params = self.get_completion_params(**kwargs)
            company = infer_litellm_company_from_model(self.model)
            if self.config.is_local or company == "local":  # route local models through their api_base
                completion_params["api_base"] = self.api_base
            elif company == "azure":  # Azure OpenAI needs endpoint, version, and key per call
                completion_params["api_base"] = self.config.azure_endpoint
                completion_params["api_version"] = self.config.api_version
                completion_params["api_key"] = self.config.azure_key
            response = completion(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)

        except Exception as e:
            raise RuntimeError(f"Error during single_generate: {e}") from e
        
        return output
    
    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        """
        Generate responses for a batch of messages.

        Args:
            batch_messages (List[List[dict]]): A list of message lists, where each sublist represents a conversation.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.

        Returns:
            List[str]: A list of responses, one per conversation.

        Note:
            Requests are issued sequentially; see the concurrent sketch after this method.
        """
        results = []
        for messages in batch_messages:
            response = self.single_generate(messages, **kwargs)
            results.append(response)
        return results
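
    # batch_generate above issues requests one at a time. A concurrent variant
    # (a hedged sketch, not part of this class) could fan out via asyncio.gather
    # using single_generate_async defined below:
    #
    #   import asyncio
    #
    #   async def batch_generate_async(self, batch_messages, **kwargs):
    #       return await asyncio.gather(
    #           *(self.single_generate_async(msgs, **kwargs) for msgs in batch_messages)
    #       )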
    
    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the async completion function.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `acompletion` function.

        Returns:
            str: A string containing the model's response.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)

        try:
            completion_params = self.get_completion_params(**kwargs)
            company = infer_litellm_company_from_model(self.model)
            if self.config.is_local or company == "local":  # route local models through their api_base
                completion_params["api_base"] = self.api_base
            elif company == "azure":  # Azure OpenAI needs endpoint, version, and key per call
                completion_params["api_base"] = self.config.azure_endpoint
                completion_params["api_version"] = self.config.api_version
                completion_params["api_key"] = self.config.azure_key
            response = await acompletion(messages=messages, **completion_params)
            if stream:
                # acompletion may return an async iterator; fall back to the
                # synchronous reader if a regular stream was returned
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate_async: {e}") from e
        
        return output
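
# A hedged async usage sketch (constructor and registry wiring live outside
# this file and are assumed; adapt field names to your actual LiteLLMConfig):
#
#   import asyncio
#
#   async def main():
#       config = LiteLLMConfig(llm_type="LiteLLM", model="gpt-4o-mini", openai_key="sk-...")
#       llm = LiteLLM(config=config)
#       print(await llm.single_generate_async([{"role": "user", "content": "Say hi."}]))
#
#   asyncio.run(main())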