import asyncio
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion
from typing import Optional, List
from litellm import token_counter, cost_per_token
from ..core.registry import register_model
from .model_configs import OpenAILLMConfig
from .base_model import BaseLLM
from .model_utils import Cost, cost_manager, get_openai_model_cost


@register_model(config_cls=OpenAILLMConfig, alias=["openai_llm"])
class OpenAILLM(BaseLLM):

    def init_model(self):
        config: OpenAILLMConfig = self.config
        self._client = self._init_client(config)  # OpenAI(api_key=config.openai_key)
        self._default_ignore_fields = [
            "llm_type", "output_response", "openai_key", "deepseek_key", "anthropic_key",
            "gemini_key", "meta_llama_key", "openrouter_key", "openrouter_base", "perplexity_key",
            "groq_key"
        ]  # parameters in OpenAILLMConfig that are not OpenAI models' input parameters
        if self.config.model not in get_openai_model_cost():
            raise KeyError(f"'{self.config.model}' is not a valid OpenAI model name!")

    def _init_client(self, config: OpenAILLMConfig):
        client = OpenAI(api_key=config.openai_key)
        return client

    def formulate_messages(self, prompts: List[str], system_messages: Optional[List[str]] = None) -> List[List[dict]]:
        if system_messages:
            assert len(prompts) == len(system_messages), \
                f"the number of prompts ({len(prompts)}) is different from the number of system_messages ({len(system_messages)})"
        else:
            system_messages = [None] * len(prompts)
        messages_list = []
        for prompt, system_message in zip(prompts, system_messages):
            messages = []
            if system_message:
                messages.append({"role": "system", "content": system_message})
            messages.append({"role": "user", "content": prompt})
            messages_list.append(messages)
        return messages_list
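
    # Example (illustrative): formulate_messages(["Hi"], ["You are helpful"]) returns
    # [[{"role": "system", "content": "You are helpful"},
    #   {"role": "user", "content": "Hi"}]]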

    def update_completion_params(self, params1: dict, params2: dict) -> dict:
        config_params: list = self.config.get_config_params()
        for key, value in params2.items():
            if key in self._default_ignore_fields:
                continue
            if key not in config_params:
                continue
            params1[key] = value
        return params1

    def get_completion_params(self, **kwargs):
        completion_params = self.config.get_set_params(ignore=self._default_ignore_fields)
        completion_params = self.update_completion_params(completion_params, kwargs)
        return completion_params

    def get_stream_output(self, response: Stream, output_response: bool = True) -> str:
        """
        Process stream response and return the complete output.

        Args:
            response: The stream response from OpenAI
            output_response: Whether to print the response in real-time

        Returns:
            str: The complete output text
        """
        output = ""
        for chunk in response:
            if not chunk.choices:
                continue  # skip chunks that carry no choices (e.g. usage-only chunks)
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    async def get_stream_output_async(self, response, output_response: bool = False) -> str:
        """
        Process async stream response and return the complete output.

        Args:
            response (AsyncIterator[ChatCompletionChunk]): The async stream response from OpenAI
            output_response (bool): Whether to print the response in real-time

        Returns:
            str: The complete output text
        """
        output = ""
        async for chunk in response:
            if not chunk.choices:
                continue  # skip chunks that carry no choices (e.g. usage-only chunks)
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    def get_completion_output(self, response: ChatCompletion, output_response: bool = True) -> str:
        output = response.choices[0].message.content
        if output_response:
            print(output)
        return output

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            completion_params = self.get_completion_params(**kwargs)
            response = self._client.chat.completions.create(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)  # calculate completion cost
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate of OpenAILLM: {str(e)}") from e
        return output

    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        return [self.single_generate(messages=one_messages, **kwargs) for one_messages in batch_messages]

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            # Create a completely new client instance to avoid thread-local storage issues;
            # this is a more aggressive approach than using a lock.
            isolated_client = self._init_client(self.config)
            completion_params = self.get_completion_params(**kwargs)
            # Run the synchronous client in an executor so it does not block the event loop.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: isolated_client.chat.completions.create(
                    messages=messages,
                    **completion_params
                )
            )
            if stream:
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)  # calculate completion cost
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate_async of OpenAILLM: {str(e)}") from e
        return output

    def _completion_cost(self, response: ChatCompletion) -> Cost:
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _stream_cost(self, messages: List[dict], output: str) -> Cost:
        model: str = self.config.model
        input_tokens = token_counter(model=model, messages=messages)
        output_tokens = token_counter(model=model, text=output)
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
        # Use LiteLLM to compute the cost; this requires the model name to be a valid LiteLLM model name.
        input_cost, output_cost = cost_per_token(
            model=self.config.model,
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
        )
        cost = Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=input_cost, output_cost=output_cost)
        return cost
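
    # Note: litellm.cost_per_token returns a (prompt_cost, completion_cost) tuple in USD,
    # e.g. cost_per_token(model="gpt-4o-mini", prompt_tokens=1000, completion_tokens=200);
    # the model name here is purely illustrative.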

    def _update_cost(self, cost: Cost):
        cost_manager.update_cost(cost=cost, model=self.config.model)
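

# Minimal usage sketch (illustrative only). The OpenAILLMConfig fields shown below are
# assumptions inferred from how the config is read above (model, openai_key, stream,
# output_response); the actual constructor and initialization flow of BaseLLM may differ.
#
# if __name__ == "__main__":
#     config = OpenAILLMConfig(
#         model="gpt-4o-mini",          # must be a key in get_openai_model_cost()
#         openai_key="sk-...",
#         stream=False,
#         output_response=True,
#     )
#     llm = OpenAILLM(config=config)    # assumed to call init_model() internally
#     messages = llm.formulate_messages(prompts=["Say hello."])
#     print(llm.single_generate(messages=messages[0]))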