Coverage for tinytroupe / openai_utils.py: 30%

198 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-28 17:48 +0000

1import os 

2import openai 

3from openai import OpenAI, AzureOpenAI 

4import time 

5import pickle 

6import logging 

7import configparser 

8from typing import Union 

9 

10 

11import tiktoken 

12from tinytroupe import utils 

13from tinytroupe.control import transactional 

14from tinytroupe import default 

15from tinytroupe import config_manager 

16 

# Module-wide logger shared with the rest of the tinytroupe package.
logger = logging.getLogger("tinytroupe")

# We'll use various configuration elements below
config = utils.read_config_file()

21 

22########################################################################### 

23# Client class 

24########################################################################### 

25 

class OpenAIClient:
    """
    A utility class for interacting with the OpenAI API.

    Provides chat-completion calls with retry and exponential-backoff logic,
    optional on-disk caching of API responses (via pickle), token counting
    (via tiktoken) and text embeddings.
    """

    def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
        """
        Initializes the client.

        Args:
            cache_api_calls (bool): Whether to cache API calls and reuse them.
            cache_file_name (str): The name of the file to use for caching API calls.
        """
        logger.debug("Initializing OpenAIClient")

        # should we cache api calls and reuse them?
        self.set_api_cache(cache_api_calls, cache_file_name)

    def set_api_cache(self, cache_api_calls, cache_file_name=default["cache_file_name"]):
        """
        Enables or disables the caching of API calls.

        Args:
            cache_api_calls (bool): Whether to cache API calls.
            cache_file_name (str): The name of the file to use for caching API calls.
        """
        self.cache_api_calls = cache_api_calls
        self.cache_file_name = cache_file_name
        if self.cache_api_calls:
            # load the cache, if any
            self.api_cache = self._load_cache()

    def _setup_from_config(self):
        """
        Sets up the OpenAI API configurations for this client.
        """
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    @config_manager.config_defaults(
        model="model",
        temperature="temperature",
        max_tokens="max_tokens",
        top_p="top_p",
        frequency_penalty="frequency_penalty",
        presence_penalty="presence_penalty",
        timeout="timeout",
        max_attempts="max_attempts",
        waiting_time="waiting_time",
        exponential_backoff_factor="exponential_backoff_factor",
        response_format=None,
        echo=None
    )
    def send_message(self,
                     current_messages,
                     dedent_messages=True,
                     model=None,
                     temperature=None,
                     max_tokens=None,
                     top_p=None,
                     frequency_penalty=None,
                     presence_penalty=None,
                     stop=None,
                     timeout=None,
                     max_attempts=None,
                     waiting_time=None,
                     exponential_backoff_factor=None,
                     n=1,
                     response_format=None,
                     enable_pydantic_model_return=False,
                     echo=False):
        """
        Sends a message to the OpenAI API and returns the response.

        Args:
            current_messages (list): A list of dictionaries representing the conversation history.
            dedent_messages (bool): Whether to dedent the messages before sending them to the API.
            model (str): The ID of the model to use for generating the response.
            temperature (float): Controls the "creativity" of the response. Higher values result in more diverse responses.
            max_tokens (int): The maximum number of tokens (words or punctuation marks) to generate in the response.
            top_p (float): Controls the "quality" of the response. Higher values result in more coherent responses.
            frequency_penalty (float): Controls the "repetition" of the response. Higher values result in less repetition.
            presence_penalty (float): Controls the "diversity" of the response. Higher values result in more diverse responses.
            stop (list): Strings that, if encountered in the generated response, will cause the generation to stop.
                Defaults to an empty list (None is normalized to []).
            max_attempts (int): The maximum number of attempts to make before giving up on generating a response.
            timeout (int): The maximum number of seconds to wait for a response from the API.
            waiting_time (int): The number of seconds to wait between requests.
            exponential_backoff_factor (int): The factor by which to increase the waiting time between requests.
            n (int): The number of completions to generate.
            response_format: The format of the response, if any.
            echo (bool): Whether to echo the input message in the response.
            enable_pydantic_model_return (bool): Whether to enable Pydantic model return instead of dict when possible.

        Returns:
            A dictionary representing the generated response, or None if the request
            was invalid or all attempts failed.
        """

        # Avoid the mutable-default-argument pitfall: the historical default for
        # `stop` was a literal []; normalizing None to [] preserves that behavior.
        if stop is None:
            stop = []

        def aux_exponential_backoff():
            # Sleeps before the next attempt and increases the waiting time.
            nonlocal waiting_time

            # in case waiting time was initially set to 0
            if waiting_time <= 0:
                waiting_time = 2

            logger.info(f"Request failed. Waiting {waiting_time} seconds between requests...")
            time.sleep(waiting_time)

            # exponential backoff
            waiting_time = waiting_time * exponential_backoff_factor

        # setup the OpenAI configurations for this client.
        self._setup_from_config()

        # dedent the messages (field 'content' only) if needed (using textwrap)
        if dedent_messages:
            for message in current_messages:
                if "content" in message:
                    message["content"] = utils.dedent(message["content"])

        # We need to adapt the parameters to the API type, so we create a dictionary with them first
        chat_api_params = {
            "model": model,
            "messages": current_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "stop": stop,
            "timeout": timeout,
            "stream": False,
            "n": n,
        }

        # "top_p" is only included when meaningful, so downstream code must not
        # assume the key is present (see _raw_model_call).
        if top_p is not None and top_p > 0:
            chat_api_params["top_p"] = top_p

        if response_format is not None:
            chat_api_params["response_format"] = response_format

        i = 0
        while i < max_attempts:
            try:
                i += 1

                try:
                    logger.debug(f"Sending messages to OpenAI API. Token count={self._count_tokens(current_messages, model)}.")
                except NotImplementedError:
                    logger.debug(f"Token count not implemented for model {model}.")

                start_time = time.monotonic()
                logger.debug(f"Calling model with client class {self.__class__.__name__}.")

                ###############################################################
                # call the model, either from the cache or from the API
                ###############################################################
                cache_key = str((model, chat_api_params))  # need string to be hashable
                if self.cache_api_calls and (cache_key in self.api_cache):
                    response = self.api_cache[cache_key]
                else:
                    if waiting_time > 0:
                        logger.info(f"Waiting {waiting_time} seconds before next API request (to avoid throttling)...")
                        time.sleep(waiting_time)

                    response = self._raw_model_call(model, chat_api_params)
                    if self.cache_api_calls:
                        self.api_cache[cache_key] = response
                        self._save_cache()

                logger.debug(f"Got response from API: {response}")
                end_time = time.monotonic()
                logger.debug(
                    f"Got response in {end_time - start_time:.2f} seconds after {i} attempts.")

                if enable_pydantic_model_return:
                    return utils.to_pydantic_or_sanitized_dict(self._raw_model_response_extractor(response), model=response_format)
                else:
                    return utils.sanitize_dict(self._raw_model_response_extractor(response))

            except (InvalidRequestError, openai.BadRequestError) as e:
                # there's no point in retrying if the request is invalid,
                # so we return None right away
                logger.error(f"[{i}] Invalid request error, won't retry: {e}")
                return None

            except openai.RateLimitError:
                logger.warning(
                    f"[{i}] Rate limit error, waiting a bit and trying again.")
                aux_exponential_backoff()

            except NonTerminalError as e:
                logger.error(f"[{i}] Non-terminal error: {e}")
                aux_exponential_backoff()

            except Exception as e:
                logger.error(f"[{i}] {type(e).__name__} Error: {e}")
                aux_exponential_backoff()

        logger.error(f"Failed to get response after {max_attempts} attempts.")
        return None

    def _raw_model_call(self, model, chat_api_params):
        """
        Calls the OpenAI API with the given parameters. Subclasses should
        override this method to implement their own API calls.

        Args:
            model (str): The model name (used to adapt parameters for reasoning models).
            chat_api_params (dict): The parameters to pass to the chat completions API.
        """

        # adjust parameters depending on the model
        if self._is_reasoning_model(model):
            # Reasoning models have slightly different parameters. Use pop() with a
            # default because some keys (notably "top_p") are only conditionally
            # present in chat_api_params; a plain `del` would raise KeyError.
            chat_api_params.pop("stream", None)
            chat_api_params.pop("temperature", None)
            chat_api_params.pop("top_p", None)
            chat_api_params.pop("frequency_penalty", None)
            chat_api_params.pop("presence_penalty", None)

            # reasoning models take "max_completion_tokens" instead of "max_tokens"
            chat_api_params["max_completion_tokens"] = chat_api_params.pop("max_tokens")

            chat_api_params["reasoning_effort"] = default["reasoning_effort"]

        # To make the log cleaner, we remove the messages from the logged parameters
        logged_params = {k: v for k, v in chat_api_params.items() if k != "messages"}

        if "response_format" in chat_api_params:
            # to enforce the response format via pydantic, we need to use a different method

            # the .parse() endpoint does not accept the "stream" parameter
            chat_api_params.pop("stream", None)

            logger.debug(f"Calling LLM model (using .parse too) with these parameters: {logged_params}. Not showing 'messages' parameter.")
            # complete message
            logger.debug(f" --> Complete messages sent to LLM: {chat_api_params['messages']}")

            result_message = self.client.beta.chat.completions.parse(
                **chat_api_params
            )

            return result_message

        else:
            logger.debug(f"Calling LLM model with these parameters: {logged_params}. Not showing 'messages' parameter.")
            return self.client.chat.completions.create(
                **chat_api_params
            )

    def _is_reasoning_model(self, model):
        """
        Returns True when the model name denotes an OpenAI reasoning model
        (o1/o3 family), which takes a different parameter set.
        """
        return "o1" in model or "o3" in model

    def _raw_model_response_extractor(self, response):
        """
        Extracts the response from the API response. Subclasses should
        override this method to implement their own response extraction.
        """
        return response.choices[0].message.to_dict()

    def _count_tokens(self, messages: list, model: str):
        """
        Count the number of OpenAI tokens in a list of messages using tiktoken.

        Adapted from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

        Args:
            messages (list): A list of dictionaries representing the conversation history.
            model (str): The name of the model to use for encoding the string.

        Returns:
            The token count as an int, or None if counting failed.

        Raises:
            NotImplementedError: If the model is not supported (caught internally
                in send_message, where it is logged and ignored).
        """
        try:
            try:
                encoding = tiktoken.encoding_for_model(model)
            except KeyError:
                logger.debug("Token count: model not found. Using cl100k_base encoding.")
                encoding = tiktoken.get_encoding("cl100k_base")

            if model in {
                "gpt-3.5-turbo-0613",
                "gpt-3.5-turbo-16k-0613",
                "gpt-4-0314",
                "gpt-4-32k-0314",
                "gpt-4-0613",
                "gpt-4-32k-0613",
            } or "o1" in model or "o3" in model:  # assuming o1/3 models work the same way
                tokens_per_message = 3
                tokens_per_name = 1
            elif model == "gpt-3.5-turbo-0301":
                tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
                tokens_per_name = -1  # if there's a name, the role is omitted
            elif "gpt-3.5-turbo" in model:
                logger.debug("Token count: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
                return self._count_tokens(messages, model="gpt-3.5-turbo-0613")
            elif ("gpt-4" in model) or ("ppo" in model) or ("alias-large" in model):
                logger.debug("Token count: gpt-4/alias-large may update over time. Returning num tokens assuming gpt-4-0613.")
                return self._count_tokens(messages, model="gpt-4-0613")
            else:
                raise NotImplementedError(
                    f"""_count_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
                )

            num_tokens = 0
            for message in messages:
                num_tokens += tokens_per_message
                for key, value in message.items():
                    num_tokens += len(encoding.encode(value))
                    if key == "name":
                        num_tokens += tokens_per_name
            num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
            return num_tokens

        except Exception as e:
            logger.error(f"Error counting tokens: {e}")
            return None

    def _save_cache(self):
        """
        Saves the API cache to disk. We use pickle to do that because some obj
        are not JSON serializable.
        """
        # Pickle files are binary: open() must NOT receive encoding/errors
        # arguments in "wb" mode (that raises ValueError). Use a context
        # manager so the file handle is always closed.
        with open(self.cache_file_name, "wb") as cache_file:
            pickle.dump(self.api_cache, cache_file)

    def _load_cache(self):
        """
        Loads the API cache from disk.

        Returns:
            The cached dict, or an empty dict if no cache file exists.
        """
        if not os.path.exists(self.cache_file_name):
            return {}

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # cache files produced by this application.
        with open(self.cache_file_name, "rb") as cache_file:
            return pickle.load(cache_file)

    def get_embedding(self, text, model=default["embedding_model"]):
        """
        Gets the embedding of the given text using the specified model.

        Args:
            text (str): The text to embed.
            model (str): The name of the model to use for embedding the text.

        Returns:
            The embedding of the text.
        """
        response = self._raw_embedding_model_call(text, model)
        return self._raw_embedding_model_response_extractor(response)

    def _raw_embedding_model_call(self, text, model):
        """
        Calls the OpenAI API to get the embedding of the given text. Subclasses should
        override this method to implement their own API calls.
        """
        return self.client.embeddings.create(
            input=[text],
            model=model
        )

    def _raw_embedding_model_response_extractor(self, response):
        """
        Extracts the embedding from the API response. Subclasses should
        override this method to implement their own response extraction.
        """
        return response.data[0].embedding

386 

class AzureClient(OpenAIClient):
    """
    An OpenAIClient variant that talks to the Azure OpenAI Service,
    authenticating either with an API key or via Entra ID.
    """

    def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
        """Initializes the Azure client; cache handling is inherited from OpenAIClient."""
        logger.debug("Initializing AzureClient")
        super().__init__(cache_api_calls, cache_file_name)

    def _setup_from_config(self):
        """
        Sets up the Azure OpenAI Service API configurations for this client,
        including the API endpoint and key.

        Prefers key-based auth when AZURE_OPENAI_KEY is set; otherwise falls
        back to Entra ID authentication.
        """
        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_version = config["OpenAI"]["AZURE_API_VERSION"]
        api_key = os.getenv("AZURE_OPENAI_KEY")

        if api_key:
            logger.info("Using Azure OpenAI Service API with key.")
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_version=api_version,
                api_key=api_key,
            )
        else:  # Use Entra ID Auth
            logger.info("Using Azure OpenAI Service API with Entra ID Auth.")
            from azure.identity import DefaultAzureCredential, get_bearer_token_provider

            token_provider = get_bearer_token_provider(
                DefaultAzureCredential(),
                "https://cognitiveservices.azure.com/.default",
            )
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_version=api_version,
                azure_ad_token_provider=token_provider,
            )

415 

416 

class HelmholtzBlabladorClient(OpenAIClient):
    """
    An OpenAIClient variant pointed at the Helmholtz Blablador
    OpenAI-compatible endpoint.
    """

    def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
        """Initializes the Blablador client; cache handling is inherited from OpenAIClient."""
        logger.debug("Initializing HelmholtzBlabladorClient")
        super().__init__(cache_api_calls, cache_file_name)

    def _setup_from_config(self):
        """
        Sets up the Helmholtz Blablador API configurations for this client.

        Falls back to a placeholder key when BLABLADOR_API_KEY is not set.
        """
        api_key = os.getenv("BLABLADOR_API_KEY", "dummy")
        self.client = OpenAI(
            api_key=api_key,
            base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
        )

431 

432########################################################################### 

433# Exceptions 

434########################################################################### 

class InvalidRequestError(Exception):
    """
    Exception raised when the request to the OpenAI API is invalid.

    send_message() treats this as terminal: it is not retried and results
    in a None return value.
    """
    pass

440 

class NonTerminalError(Exception):
    """
    Exception raised when an unspecified error occurs but we know we can retry.

    send_message() responds to this by applying exponential backoff and
    attempting the call again.
    """
    pass

446 

447########################################################################### 

448# Clients registry 

449# 

450# We can have potentially different clients, so we need a place to  

451# register them and retrieve them when needed. 

452# 

453# We support both OpenAI and Azure OpenAI Service API by default. 

454# Thus, we need to set the API parameters based on the choice of the user. 

455# This is done within specialized classes. 

456# 

457# It is also possible to register custom clients, to access internal or 

458# otherwise non-conventional API endpoints. 

459########################################################################### 

# Maps an API type name (e.g. "openai", "azure") to its registered client instance.
_api_type_to_client = {}
# When not None, overrides the API type read from the configuration (see force_api_type).
_api_type_override = None

462 

def register_client(api_type, client):
    """
    Makes a client available under the given API type name.

    Args:
        api_type (str): The API type under which to register the client.
        client: The client instance to register.
    """
    _api_type_to_client[api_type] = client

472 

def _get_client_for_api_type(api_type):
    """
    Returns the client for the given API type.

    Args:
        api_type (str): The API type for which we want to get the client.

    Raises:
        ValueError: If no client is registered for the given API type.
    """
    try:
        return _api_type_to_client[api_type]
    except KeyError:
        # Suppress the noisy KeyError traceback context ("During handling of
        # the above exception..."): the ValueError message below is the
        # actionable information for the user.
        raise ValueError(f"API type {api_type} is not supported. Please check the 'config.ini' file.") from None

484 

def client():
    """
    Returns the client for the configured API type.

    An override installed via force_api_type() takes precedence over the
    value read from the configuration file.
    """
    if _api_type_override is not None:
        api_type = _api_type_override
    else:
        api_type = config["OpenAI"]["API_TYPE"]

    logger.debug(f"Using API type {api_type}.")
    return _get_client_for_api_type(api_type)

493 

494 

495# TODO simplify the custom configuration methods below 

496 

def force_api_type(api_type):
    """
    Forces the use of the given API type, thus overriding any other configuration.

    Args:
        api_type (str): The API type to use from now on.
    """
    # client() consults this module-level override before the config file
    global _api_type_override
    _api_type_override = api_type

506 

def force_api_cache(cache_api_calls, cache_file_name=default["cache_file_name"]):
    """
    Forces the use of the given API cache configuration, thus overriding any other configuration.

    Args:
        cache_api_calls (bool): Whether to cache API calls.
        cache_file_name (str): The name of the file to use for caching API calls.
    """
    # apply the same cache settings to every registered client
    # (loop variable renamed so it does not shadow the module-level client())
    for registered_client in _api_type_to_client.values():
        registered_client.set_api_cache(cache_api_calls, cache_file_name)

518 

# default client
# Register the built-in clients at import time so client() works out of the box;
# additional clients can be added later via register_client().
register_client("openai", OpenAIClient())
register_client("azure", AzureClient())
register_client("helmholtz-blablador", HelmholtzBlabladorClient())

523 

524 

525