"""Model wrapper to interact with OpenAI models."""

import abc
import ast
import json
from typing import Any, Dict, List, Mapping, Optional

import openai

import parameters


class OpenAIModel(abc.ABC):
    """Wrapper around an ``openai.OpenAI`` client bound to a single chat model.

    Prompts are supplied as a ``role -> content`` mapping and sent through the
    client's ``chat.completions.create`` endpoint using ``self.model_name``.
    """

    # NOTE(review): not read anywhere in this class (the key passed to
    # __init__ is used instead); kept in case external code references it.
    API_KEY = ""

    # TODO(Maani): Add support for more generation options like:
    # 1. temperature        (done: accepted by generate_response)
    # 2. top-p
    # 3. stop sequences
    # 4. num_outputs
    # 5. response_format    (done: accepted by generate_response)
    # 6. seed

    def __init__(self, model_name: str, API_KEY: str):
        """Create the underlying OpenAI client.

        Args:
            model_name: Model identifier sent with every request
                (e.g. ``"gpt-4o-mini"``).
            API_KEY: API key handed to ``openai.OpenAI``.

        Raises:
            Exception: If the client cannot be constructed; the original
                error is chained as ``__cause__``.
        """
        try:
            self.client = openai.OpenAI(
                # The key is passed explicitly rather than read from the
                # environment by this wrapper.
                api_key=API_KEY
            )
        except Exception as exc:
            raise Exception(
                "Failed to initialize OpenAI model client. See traceback for more details.",
            ) from exc
        self.model_name = model_name

    def prepare_input(self, prompt_dict: Mapping[str, str]) -> List[Dict[str, str]]:
        """Convert a ``role -> content`` mapping into a chat message list.

        Messages are emitted in the mapping's iteration order. Because the
        input is keyed by role, each role can appear at most once.

        Args:
            prompt_dict: Mapping such as ``{"system": ..., "user": ...}``.

        Returns:
            A list of ``{"role": role, "content": content}`` dicts.

        Raises:
            Exception: If ``prompt_dict`` is empty or cannot be iterated as
                ``(role, content)`` pairs via ``.items()``; the underlying
                error is chained as ``__cause__``.
        """
        try:
            conversation = [
                {"role": role, "content": content}
                for role, content in prompt_dict.items()
            ]
            # An empty conversation can never produce a valid request, so
            # reject it here instead of letting the API call fail later.
            if not conversation:
                raise ValueError("prompt_dict contains no messages")
        except Exception as exc:
            raise Exception(
                f"Incomplete Prompt Dictionary Passed. Expected to have atleast a role and it's content.\nPassed dict: {prompt_dict}",
            ) from exc
        return conversation

    def generate_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,
        temperature: float = 0.6,
        response_format: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Send ``prompt_dict`` to the model and return the first choice's text.

        Args:
            prompt_dict: ``role -> content`` mapping; see ``prepare_input``.
            max_output_tokens: Upper bound on generated tokens. When ``None``
                or ``0``, no ``max_tokens`` value is sent.
            temperature: Sampling temperature forwarded to the API.
            response_format: Optional ``response_format`` payload, e.g.
                ``{"type": "json_object"}``. When ``None``, it is not sent.

        Returns:
            ``response.choices[0].message.content`` of the completion.
            NOTE(review): the SDK types this field as optional, so callers
            may receive ``None`` in some cases — confirm against SDK docs.

        Raises:
            Exception: If the prompt is malformed (raised by
                ``prepare_input``), or if the API call or response access
                fails (original error chained as ``__cause__``).
        """
        conversation = self.prepare_input(prompt_dict)

        request: Dict[str, Any] = {
            "model": self.model_name,
            "messages": conversation,
            "temperature": temperature,
        }
        # Forward optional settings only when they were actually provided,
        # so no explicit nulls are sent for them.
        if max_output_tokens:
            request["max_tokens"] = max_output_tokens
        if response_format is not None:
            request["response_format"] = response_format

        try:
            response = self.client.chat.completions.create(**request)
            return response.choices[0].message.content
        except Exception as exc:
            raise Exception(
                f"Exception in generating model response.\nModel name: {self.model_name}\nInput prompt: {str(conversation)}",
            ) from exc

    def generate_valid_json_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,
        temperature: float = 0.7,
    ) -> Any:
        """Generate a response with retries, returning the parsed JSON value.

        The request is made in JSON-object response mode. Up to
        ``parameters.MAX_TRIES`` attempts are made. An attempt fails if the
        API call raises or if the reply is not valid JSON.

        Args:
            prompt_dict: ``role -> content`` mapping; see ``prepare_input``.
            max_output_tokens: Forwarded to ``generate_response``.
            temperature: Forwarded to ``generate_response``.

        Returns:
            The decoded JSON value (normally a ``dict``, since JSON-object
            mode is requested).

        Raises:
            Exception: If every attempt fails. The last attempt's error is
                chained as ``__cause__``.
        """
        last_error: Optional[BaseException] = None
        for _ in range(int(parameters.MAX_TRIES)):
            try:
                model_response = self.generate_response(
                    prompt_dict,
                    max_output_tokens,
                    temperature,
                    {"type": "json_object"},
                )
                # json.loads (not ast.literal_eval): JSON literals such as
                # true/false/null are not valid Python literals and would be
                # wrongly treated as malformed output.
                return json.loads(model_response)
            except Exception as exc:
                # Retry on API failures and malformed/None content alike, but
                # remember the failure so the final error explains itself.
                last_error = exc
        raise Exception(
            f"Maximum retries met before valid JSON structure was found.\nModel name: {self.model_name}\nInput prompt: {str(prompt_dict)}"
        ) from last_error


# Module-level client for gpt-4o-mini, created when this module is imported
# using the key stored in ``parameters.OPEN_AI_API_KEY``.
GPT_4O_MINI = OpenAIModel("gpt-4o-mini", parameters.OPEN_AI_API_KEY)