import asyncio
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion
from typing import Optional, List
from litellm import token_counter, cost_per_token

from ..core.registry import register_model
from .model_configs import OpenAILLMConfig
from .base_model import BaseLLM
from .model_utils import Cost, cost_manager, get_openai_model_cost


@register_model(config_cls=OpenAILLMConfig, alias=["openai_llm"])
class OpenAILLM(BaseLLM):

    def init_model(self):
        config: OpenAILLMConfig = self.config
        self._client = self._init_client(config)  # OpenAI(api_key=config.openai_key)
        # fields of OpenAILLMConfig that are not valid OpenAI completion parameters
        self._default_ignore_fields = [
            "llm_type", "output_response", "openai_key", "deepseek_key", "anthropic_key",
            "gemini_key", "meta_llama_key", "openrouter_key", "openrouter_base",
            "perplexity_key", "groq_key"
        ]
        if self.config.model not in get_openai_model_cost():
            raise KeyError(f"'{self.config.model}' is not a valid OpenAI model name!")

    def _init_client(self, config: OpenAILLMConfig):
        client = OpenAI(api_key=config.openai_key)
        return client

    def formulate_messages(self, prompts: List[str], system_messages: Optional[List[str]] = None) -> List[List[dict]]:
        # Build one chat-message list per prompt, prepending the matching system message when given.
        if system_messages:
            assert len(prompts) == len(system_messages), \
                f"the number of prompts ({len(prompts)}) is different from the number of system_messages ({len(system_messages)})"
        else:
            system_messages = [None] * len(prompts)
        messages_list = []
        for prompt, system_message in zip(prompts, system_messages):
            messages = []
            if system_message:
                messages.append({"role": "system", "content": system_message})
            messages.append({"role": "user", "content": prompt})
            messages_list.append(messages)
        return messages_list

    def update_completion_params(self, params1: dict, params2: dict) -> dict:
        # Overwrite params1 with values from params2 that are valid config fields and not ignored.
        config_params: list = self.config.get_config_params()
        for key, value in params2.items():
            if key in self._default_ignore_fields:
                continue
            if key not in config_params:
                continue
            params1[key] = value
        return params1

    def get_completion_params(self, **kwargs):
        completion_params = self.config.get_set_params(ignore=self._default_ignore_fields)
        completion_params = self.update_completion_params(completion_params, kwargs)
        return completion_params

    def get_stream_output(self, response: Stream, output_response: bool = True) -> str:
        """
        Process stream response and return the complete output.

        Args:
            response: The stream response from OpenAI
            output_response: Whether to print the response in real-time

        Returns:
            str: The complete output text
        """
        output = ""
        for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    async def get_stream_output_async(self, response, output_response: bool = False) -> str:
        """
        Process async stream response and return the complete output.

        Args:
            response (AsyncIterator[ChatCompletionChunk]): The async stream response from OpenAI
            output_response (bool): Whether to print the response in real-time

        Returns:
            str: The complete output text
        """
        output = ""
        async for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    def get_completion_output(self, response: ChatCompletion, output_response: bool = True) -> str:
        output = response.choices[0].message.content
        if output_response:
            print(output)
        return output

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        stream = kwargs["stream"] if "stream" in kwargs else self.config.stream
        output_response = kwargs["output_response"] if "output_response" in kwargs else self.config.output_response

        try:
            completion_params = self.get_completion_params(**kwargs)
            response = self._client.chat.completions.create(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)  # calculate completion cost
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate of OpenAILLM: {str(e)}")

        return output

    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        return [self.single_generate(messages=one_messages, **kwargs) for one_messages in batch_messages]

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)

        try:
            # Create a completely new client instance to avoid thread-local storage issues.
            # This is a more aggressive approach than using a lock.
            # isolated_client = OpenAI(api_key=self.config.openai_key)
            isolated_client = self._init_client(self.config)
            completion_params = self.get_completion_params(**kwargs)

            # Run the synchronous client in the default thread-pool executor so the event loop is not blocked.
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                None,
                lambda: isolated_client.chat.completions.create(
                    messages=messages,
                    **completion_params
                )
            )

            if stream:
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)  # calculate completion cost
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate_async of OpenAILLM: {str(e)}")

        return output

    def _completion_cost(self, response: ChatCompletion) -> Cost:
        # Cost from the token usage reported in the completion response.
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _stream_cost(self, messages: List[dict], output: str) -> Cost:
        # Streaming responses do not report usage here, so count tokens with LiteLLM.
        model: str = self.config.model
        input_tokens = token_counter(model=model, messages=messages)
        output_tokens = token_counter(model=model, text=output)
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
        # Use LiteLLM to compute cost; requires the model name to be a valid LiteLLM model name.
        input_cost, output_cost = cost_per_token(
            model=self.config.model,
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
        )
        cost = Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=input_cost, output_cost=output_cost)
        return cost

    def _update_cost(self, cost: Cost):
        cost_manager.update_cost(cost=cost, model=self.config.model)
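

# Minimal usage sketch (illustrative only, not part of the module's original code). It assumes
# OpenAILLMConfig exposes `model`, `openai_key`, `stream`, and `output_response` fields (as the
# accesses above suggest), that BaseLLM accepts the config via `OpenAILLM(config=config)` and
# calls `init_model()` on construction, and that "gpt-4o-mini" is a key in get_openai_model_cost().
if __name__ == "__main__":
    import os

    config = OpenAILLMConfig(
        model="gpt-4o-mini",  # must appear in get_openai_model_cost(), otherwise init_model raises KeyError
        openai_key=os.environ["OPENAI_API_KEY"],
        stream=False,
        output_response=True,
    )
    llm = OpenAILLM(config=config)
    messages = llm.formulate_messages(
        prompts=["Explain exponential backoff in one sentence."],
        system_messages=["You are a concise technical assistant."],
    )
    # single_generate prints the completion (output_response=True) and records its cost via cost_manager.
    print(llm.single_generate(messages=messages[0]))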