Coverage for tinytroupe/openai_utils.py: 30% (198 statements)
coverage.py v7.13.4, created at 2026-02-28 17:48 +0000
1import os
2import openai
3from openai import OpenAI, AzureOpenAI
4import time
5import pickle
6import logging
7import configparser
8from typing import Union
11import tiktoken
12from tinytroupe import utils
13from tinytroupe.control import transactional
14from tinytroupe import default
15from tinytroupe import config_manager
17logger = logging.getLogger("tinytroupe")
19# We'll use various configuration elements below
20config = utils.read_config_file()
22###########################################################################
23# Client class
24###########################################################################
26class OpenAIClient:
27 """
28 A utility class for interacting with the OpenAI API.
29 """
31 def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
32 logger.debug("Initializing OpenAIClient")
34 # should we cache api calls and reuse them?
35 self.set_api_cache(cache_api_calls, cache_file_name)
37 def set_api_cache(self, cache_api_calls, cache_file_name=default["cache_file_name"]):
38 """
39 Enables or disables the caching of API calls.
41 Args:
42 cache_file_name (str): The name of the file to use for caching API calls.
43 """
44 self.cache_api_calls = cache_api_calls
45 self.cache_file_name = cache_file_name
46 if self.cache_api_calls:
47 # load the cache, if any
48 self.api_cache = self._load_cache()
51 def _setup_from_config(self):
52 """
53 Sets up the OpenAI API configurations for this client.
54 """
55 self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
57 @config_manager.config_defaults(
58 model="model",
59 temperature="temperature",
60 max_tokens="max_tokens",
61 top_p="top_p",
62 frequency_penalty="frequency_penalty",
63 presence_penalty="presence_penalty",
64 timeout="timeout",
65 max_attempts="max_attempts",
66 waiting_time="waiting_time",
67 exponential_backoff_factor="exponential_backoff_factor",
68 response_format=None,
69 echo=None
70 )
71 def send_message(self,
72 current_messages,
73 dedent_messages=True,
74 model=None,
75 temperature=None,
76 max_tokens=None,
77 top_p=None,
78 frequency_penalty=None,
79 presence_penalty=None,
80 stop=[],
81 timeout=None,
82 max_attempts=None,
83 waiting_time=None,
84 exponential_backoff_factor=None,
85 n = 1,
86 response_format=None,
87 enable_pydantic_model_return=False,
88 echo=False):
89 """
90 Sends a message to the OpenAI API and returns the response.
92 Args:
93 current_messages (list): A list of dictionaries representing the conversation history.
94 dedent_messages (bool): Whether to dedent the messages before sending them to the API.
95 model (str): The ID of the model to use for generating the response.
96 temperature (float): Controls the "creativity" of the response. Higher values result in more diverse responses.
97 max_tokens (int): The maximum number of tokens (words or punctuation marks) to generate in the response.
98 top_p (float): Controls the "quality" of the response. Higher values result in more coherent responses.
99 frequency_penalty (float): Controls the "repetition" of the response. Higher values result in less repetition.
100 presence_penalty (float): Controls the "diversity" of the response. Higher values result in more diverse responses.
101 stop (str): A string that, if encountered in the generated response, will cause the generation to stop.
102 max_attempts (int): The maximum number of attempts to make before giving up on generating a response.
103 timeout (int): The maximum number of seconds to wait for a response from the API.
104 waiting_time (int): The number of seconds to wait between requests.
105 exponential_backoff_factor (int): The factor by which to increase the waiting time between requests.
106 n (int): The number of completions to generate.
107 response_format: The format of the response, if any.
108 echo (bool): Whether to echo the input message in the response.
109 enable_pydantic_model_return (bool): Whether to enable Pydantic model return instead of dict when possible.
111 Returns:
112 A dictionary representing the generated response.
113 """
115 def aux_exponential_backoff():
116 nonlocal waiting_time
118 # in case waiting time was initially set to 0
119 if waiting_time <= 0:
120 waiting_time = 2
122 logger.info(f"Request failed. Waiting {waiting_time} seconds between requests...")
123 time.sleep(waiting_time)
125 # exponential backoff
126 waiting_time = waiting_time * exponential_backoff_factor
128 # setup the OpenAI configurations for this client.
129 self._setup_from_config()
131 # dedent the messages (field 'content' only) if needed (using textwrap)
132 if dedent_messages:
133 for message in current_messages:
134 if "content" in message:
135 message["content"] = utils.dedent(message["content"])
138 # We need to adapt the parameters to the API type, so we create a dictionary with them first
139 chat_api_params = {
140 "model": model,
141 "messages": current_messages,
142 "temperature": temperature,
143 "max_tokens":max_tokens,
144 "frequency_penalty": frequency_penalty,
145 "presence_penalty": presence_penalty,
146 "stop": stop,
147 "timeout": timeout,
148 "stream": False,
149 "n": n,
150 }
152 if top_p is not None and top_p > 0:
153 chat_api_params["top_p"] = top_p
155 if response_format is not None:
156 chat_api_params["response_format"] = response_format
158 i = 0
159 while i < max_attempts:
160 try:
161 i += 1
163 try:
164 logger.debug(f"Sending messages to OpenAI API. Token count={self._count_tokens(current_messages, model)}.")
165 except NotImplementedError:
166 logger.debug(f"Token count not implemented for model {model}.")
168 start_time = time.monotonic()
169 logger.debug(f"Calling model with client class {self.__class__.__name__}.")
171 ###############################################################
172 # call the model, either from the cache or from the API
173 ###############################################################
174 cache_key = str((model, chat_api_params)) # need string to be hashable
175 if self.cache_api_calls and (cache_key in self.api_cache):
176 response = self.api_cache[cache_key]
177 else:
178 if waiting_time > 0:
179 logger.info(f"Waiting {waiting_time} seconds before next API request (to avoid throttling)...")
180 time.sleep(waiting_time)
182 response = self._raw_model_call(model, chat_api_params)
183 if self.cache_api_calls:
184 self.api_cache[cache_key] = response
185 self._save_cache()
188 logger.debug(f"Got response from API: {response}")
189 end_time = time.monotonic()
190 logger.debug(
191 f"Got response in {end_time - start_time:.2f} seconds after {i} attempts.")
193 if enable_pydantic_model_return:
194 return utils.to_pydantic_or_sanitized_dict(self._raw_model_response_extractor(response), model=response_format)
195 else:
196 return utils.sanitize_dict(self._raw_model_response_extractor(response))
198 except InvalidRequestError as e:
199 logger.error(f"[{i}] Invalid request error, won't retry: {e}")
201 # there's no point in retrying if the request is invalid
202 # so we return None right away
203 return None
205 except openai.BadRequestError as e:
206 logger.error(f"[{i}] Invalid request error, won't retry: {e}")
208 # there's no point in retrying if the request is invalid
209 # so we return None right away
210 return None
212 except openai.RateLimitError:
213 logger.warning(
214 f"[{i}] Rate limit error, waiting a bit and trying again.")
215 aux_exponential_backoff()
217 except NonTerminalError as e:
218 logger.error(f"[{i}] Non-terminal error: {e}")
219 aux_exponential_backoff()
221 except Exception as e:
222 logger.error(f"[{i}] {type(e).__name__} Error: {e}")
223 aux_exponential_backoff()
225 logger.error(f"Failed to get response after {max_attempts} attempts.")
226 return None
228 def _raw_model_call(self, model, chat_api_params):
229 """
230 Calls the OpenAI API with the given parameters. Subclasses should
231 override this method to implement their own API calls.
232 """
234 # adjust parameters depending on the model
235 if self._is_reasoning_model(model):
236 # Reasoning models have slightly different parameters
237 del chat_api_params["stream"]
238 del chat_api_params["temperature"]
239 del chat_api_params["top_p"]
240 del chat_api_params["frequency_penalty"]
241 del chat_api_params["presence_penalty"]
243 chat_api_params["max_completion_tokens"] = chat_api_params["max_tokens"]
244 del chat_api_params["max_tokens"]
246 chat_api_params["reasoning_effort"] = default["reasoning_effort"]
249 # To make the log cleaner, we remove the messages from the logged parameters
250 logged_params = {k: v for k, v in chat_api_params.items() if k != "messages"}
252 if "response_format" in chat_api_params:
253 # to enforce the response format via pydantic, we need to use a different method
255 if "stream" in chat_api_params:
256 del chat_api_params["stream"]
258 logger.debug(f"Calling LLM model (using .parse too) with these parameters: {logged_params}. Not showing 'messages' parameter.")
259 # complete message
260 logger.debug(f" --> Complete messages sent to LLM: {chat_api_params['messages']}")
262 result_message = self.client.beta.chat.completions.parse(
263 **chat_api_params
264 )
266 return result_message
268 else:
269 logger.debug(f"Calling LLM model with these parameters: {logged_params}. Not showing 'messages' parameter.")
270 return self.client.chat.completions.create(
271 **chat_api_params
272 )
274 def _is_reasoning_model(self, model):
275 return "o1" in model or "o3" in model
277 def _raw_model_response_extractor(self, response):
278 """
279 Extracts the response from the API response. Subclasses should
280 override this method to implement their own response extraction.
281 """
282 return response.choices[0].message.to_dict()
284 def _count_tokens(self, messages: list, model: str):
285 """
286 Count the number of OpenAI tokens in a list of messages using tiktoken.
288 Adapted from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
290 Args:
291 messages (list): A list of dictionaries representing the conversation history.
292 model (str): The name of the model to use for encoding the string.
293 """
294 try:
295 try:
296 encoding = tiktoken.encoding_for_model(model)
297 except KeyError:
298 logger.debug("Token count: model not found. Using cl100k_base encoding.")
299 encoding = tiktoken.get_encoding("cl100k_base")
301 if model in {
302 "gpt-3.5-turbo-0613",
303 "gpt-3.5-turbo-16k-0613",
304 "gpt-4-0314",
305 "gpt-4-32k-0314",
306 "gpt-4-0613",
307 "gpt-4-32k-0613",
308 } or "o1" in model or "o3" in model: # assuming o1/3 models work the same way
309 tokens_per_message = 3
310 tokens_per_name = 1
311 elif model == "gpt-3.5-turbo-0301":
312 tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
313 tokens_per_name = -1 # if there's a name, the role is omitted
314 elif "gpt-3.5-turbo" in model:
315 logger.debug("Token count: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
316 return self._count_tokens(messages, model="gpt-3.5-turbo-0613")
317 elif ("gpt-4" in model) or ("ppo" in model) or ("alias-large" in model):
318 logger.debug("Token count: gpt-4/alias-large may update over time. Returning num tokens assuming gpt-4-0613.")
319 return self._count_tokens(messages, model="gpt-4-0613")
320 else:
321 raise NotImplementedError(
322 f"""_count_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
323 )
325 num_tokens = 0
326 for message in messages:
327 num_tokens += tokens_per_message
328 for key, value in message.items():
329 num_tokens += len(encoding.encode(value))
330 if key == "name":
331 num_tokens += tokens_per_name
332 num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
333 return num_tokens
335 except Exception as e:
336 logger.error(f"Error counting tokens: {e}")
337 return None
339 def _save_cache(self):
340 """
341 Saves the API cache to disk. We use pickle to do that because some obj
342 are not JSON serializable.
343 """
344 # use pickle to save the cache
345 pickle.dump(self.api_cache, open(self.cache_file_name, "wb", encoding="utf-8", errors="replace"))
348 def _load_cache(self):
350 """
351 Loads the API cache from disk.
352 """
353 # unpickle
354 return pickle.load(open(self.cache_file_name, "rb", encoding="utf-8", errors="replace")) if os.path.exists(self.cache_file_name) else {}
356 def get_embedding(self, text, model=default["embedding_model"]):
357 """
358 Gets the embedding of the given text using the specified model.
360 Args:
361 text (str): The text to embed.
362 model (str): The name of the model to use for embedding the text.
364 Returns:
365 The embedding of the text.
366 """
367 response = self._raw_embedding_model_call(text, model)
368 return self._raw_embedding_model_response_extractor(response)
370 def _raw_embedding_model_call(self, text, model):
371 """
372 Calls the OpenAI API to get the embedding of the given text. Subclasses should
373 override this method to implement their own API calls.
374 """
375 return self.client.embeddings.create(
376 input=[text],
377 model=model
378 )
380 def _raw_embedding_model_response_extractor(self, response):
381 """
382 Extracts the embedding from the API response. Subclasses should
383 override this method to implement their own response extraction.
384 """
385 return response.data[0].embedding
387class AzureClient(OpenAIClient):
389 def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
390 logger.debug("Initializing AzureClient")
392 super().__init__(cache_api_calls, cache_file_name)
394 def _setup_from_config(self):
395 """
396 Sets up the Azure OpenAI Service API configurations for this client,
397 including the API endpoint and key.
398 """
399 if os.getenv("AZURE_OPENAI_KEY"):
400 logger.info("Using Azure OpenAI Service API with key.")
401 self.client = AzureOpenAI(azure_endpoint= os.getenv("AZURE_OPENAI_ENDPOINT"),
402 api_version = config["OpenAI"]["AZURE_API_VERSION"],
403 api_key = os.getenv("AZURE_OPENAI_KEY"))
404 else: # Use Entra ID Auth
405 logger.info("Using Azure OpenAI Service API with Entra ID Auth.")
406 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
408 credential = DefaultAzureCredential()
409 token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
410 self.client = AzureOpenAI(
411 azure_endpoint= os.getenv("AZURE_OPENAI_ENDPOINT"),
412 api_version = config["OpenAI"]["AZURE_API_VERSION"],
413 azure_ad_token_provider=token_provider
414 )
417class HelmholtzBlabladorClient(OpenAIClient):
419 def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
420 logger.debug("Initializing HelmholtzBlabladorClient")
421 super().__init__(cache_api_calls, cache_file_name)
423 def _setup_from_config(self):
424 """
425 Sets up the Helmholtz Blablador API configurations for this client.
426 """
427 self.client = OpenAI(
428 base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
429 api_key=os.getenv("BLABLADOR_API_KEY", "dummy"),
430 )
432###########################################################################
433# Exceptions
434###########################################################################
435class InvalidRequestError(Exception):
436 """
437 Exception raised when the request to the OpenAI API is invalid.
438 """
439 pass
441class NonTerminalError(Exception):
442 """
443 Exception raised when an unspecified error occurs but we know we can retry.
444 """
445 pass
447###########################################################################
448# Clients registry
449#
450# We can have potentially different clients, so we need a place to
451# register them and retrieve them when needed.
452#
453# We support both OpenAI and Azure OpenAI Service API by default.
454# Thus, we need to set the API parameters based on the choice of the user.
455# This is done within specialized classes.
456#
457# It is also possible to register custom clients, to access internal or
458# otherwise non-conventional API endpoints.
459###########################################################################
460_api_type_to_client = {}
461_api_type_override = None
463def register_client(api_type, client):
464 """
465 Registers a client for the given API type.
467 Args:
468 api_type (str): The API type for which we want to register the client.
469 client: The client to register.
470 """
471 _api_type_to_client[api_type] = client
473def _get_client_for_api_type(api_type):
474 """
475 Returns the client for the given API type.
477 Args:
478 api_type (str): The API type for which we want to get the client.
479 """
480 try:
481 return _api_type_to_client[api_type]
482 except KeyError:
483 raise ValueError(f"API type {api_type} is not supported. Please check the 'config.ini' file.")
485def client():
486 """
487 Returns the client for the configured API type.
488 """
489 api_type = config["OpenAI"]["API_TYPE"] if _api_type_override is None else _api_type_override
491 logger.debug(f"Using API type {api_type}.")
492 return _get_client_for_api_type(api_type)
495# TODO simplify the custom configuration methods below
497def force_api_type(api_type):
498 """
499 Forces the use of the given API type, thus overriding any other configuration.
501 Args:
502 api_type (str): The API type to use.
503 """
504 global _api_type_override
505 _api_type_override = api_type
507def force_api_cache(cache_api_calls, cache_file_name=default["cache_file_name"]):
508 """
509 Forces the use of the given API cache configuration, thus overriding any other configuration.
511 Args:
512 cache_api_calls (bool): Whether to cache API calls.
513 cache_file_name (str): The name of the file to use for caching API calls.
514 """
515 # set the cache parameters on all clients
516 for client in _api_type_to_client.values():
517 client.set_api_cache(cache_api_calls, cache_file_name)
519# default client
520register_client("openai", OpenAIClient())
521register_client("azure", AzureClient())
522register_client("helmholtz-blablador", HelmholtzBlabladorClient())