Spaces:
Sleeping
Sleeping
| """Model wrapper to interact with OpenAI models.""" | |
import abc
import ast
import json
from typing import Any, Dict, List, Mapping, Optional

import openai

import parameters
class OpenAIModel(abc.ABC):
    """Thin wrapper around the OpenAI chat-completions client for one model.

    Prompts are passed as a ``{role: content}`` mapping and converted into the
    ``messages`` list expected by ``client.chat.completions.create``.
    """

    # Not read by any method of this class (``__init__`` takes the key as a
    # parameter instead); kept only for backward compatibility with callers
    # that may reference ``OpenAIModel.API_KEY``.
    API_KEY = ""

    # TODO(Maani): Add support for more generation options like:
    # 1. top-p
    # 2. stop sequences
    # 3. num_outputs
    # 4. seed
    # (temperature and response_format are already supported by
    # ``generate_response``.)

    def __init__(self, model_name: str, API_KEY: Optional[str]):
        """Create an OpenAI client that will issue requests for ``model_name``.

        Args:
            model_name: Name of the OpenAI model to query (e.g. "gpt-4o-mini").
            API_KEY: API key handed to ``openai.OpenAI``. Per the openai-python
                docs, passing ``None`` lets the SDK read ``OPENAI_API_KEY``
                from the environment instead.

        Raises:
            Exception: If the OpenAI client cannot be constructed; the original
                error is attached as ``__cause__``.
        """
        try:
            self.client = openai.OpenAI(api_key=API_KEY)
        except Exception as exc:
            raise Exception(
                "Failed to initialize OpenAI model client. See traceback for more details.",
            ) from exc
        self.model_name = model_name

    def prepare_input(self, prompt_dict: Mapping[str, str]) -> List[Dict[str, str]]:
        """Convert a ``{role: content}`` mapping into a chat ``messages`` list.

        Because roles are mapping keys, each role can appear at most once.

        Args:
            prompt_dict: Mapping from chat role (e.g. "system", "user") to the
                message content for that role. Messages are emitted in the
                mapping's iteration order.

        Returns:
            One ``{"role": ..., "content": ...}`` dict per mapping entry.

        Raises:
            Exception: If ``prompt_dict`` is empty or cannot be iterated with
                ``.items()``.
        """
        try:
            conversation = [
                {"role": role, "content": content}
                for role, content in prompt_dict.items()
            ]
        except Exception as exc:
            raise self._incomplete_prompt_error(prompt_dict) from exc
        # An empty conversation would only fail later inside the API call;
        # reject it here with the message that describes the actual problem.
        if not conversation:
            raise self._incomplete_prompt_error(prompt_dict)
        return conversation

    @staticmethod
    def _incomplete_prompt_error(prompt_dict: Any) -> Exception:
        """Build the error raised when ``prompt_dict`` yields no usable messages."""
        return Exception(
            f"Incomplete Prompt Dictionary Passed. Expected to have atleast a role and it's content.\nPassed dict: {prompt_dict}",
        )

    def generate_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = 0.6,
        response_format: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Send one chat-completion request and return the first choice's text.

        Args:
            prompt_dict: ``{role: content}`` mapping, see ``prepare_input``.
            max_output_tokens: Upper bound on generated tokens. A falsy value
                (``None`` or 0) leaves the limit unset.
            temperature: Sampling temperature. ``None`` leaves it unset.
            response_format: Optional response format, e.g.
                ``{"type": "json_object"}``. ``None`` leaves it unset.

        Returns:
            ``response.choices[0].message.content``. Per the OpenAI SDK types
            this may be ``None`` when the model returns no text content.

        Raises:
            Exception: If the prompt is invalid (see ``prepare_input``) or the
                API request fails; API errors are attached as ``__cause__``.
        """
        conversation = self.prepare_input(prompt_dict)

        # Forward only the options that are actually set, so the SDK applies
        # its own defaults instead of receiving explicit nulls.
        request_options: Dict[str, Any] = {}
        if max_output_tokens:
            request_options["max_tokens"] = max_output_tokens
        if temperature is not None:
            request_options["temperature"] = temperature
        if response_format is not None:
            request_options["response_format"] = response_format

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=conversation,
                **request_options,
            )
            return response.choices[0].message.content
        except Exception as exc:
            raise Exception(
                f"Exception in generating model response.\nModel name: {self.model_name}\nInput prompt: {str(conversation)}",
            ) from exc

    def generate_valid_json_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = 0.7,
    ) -> Any:
        """Request a JSON-object response, retrying until it parses.

        Up to ``parameters.MAX_TRIES`` attempts are made. Any failure is
        retried, whether an API error or a reply that is not valid JSON.

        Args:
            prompt_dict: ``{role: content}`` mapping, see ``prepare_input``.
            max_output_tokens: Upper bound on generated tokens (falsy = unset).
            temperature: Sampling temperature (``None`` = unset).

        Returns:
            The decoded JSON value, as produced by ``json.loads``.

        Raises:
            Exception: If no attempt produced valid JSON. The last underlying
                error is attached as ``__cause__``.
        """
        last_error: Optional[Exception] = None
        for _ in range(int(parameters.MAX_TRIES)):
            try:
                model_response = self.generate_response(
                    prompt_dict, max_output_tokens, temperature, {"type": "json_object"}
                )
                # json.loads, not ast.literal_eval: JSON's true/false/null are
                # not Python literals and would make valid JSON look invalid.
                return json.loads(model_response)
            except Exception as exc:  # Retry boundary: remember why we failed.
                last_error = exc
        raise Exception(
            f"Maximum retries met before valid JSON structure was found.\nModel name: {self.model_name}\nInput prompt: {str(prompt_dict)}"
        ) from last_error
# Shared, ready-to-use wrapper for the "gpt-4o-mini" model. It is built at
# import time, so importing this module constructs an OpenAI client using
# the key stored in ``parameters.OPEN_AI_API_KEY``.
GPT_4O_MINI = OpenAIModel(
    model_name="gpt-4o-mini",
    API_KEY=parameters.OPEN_AI_API_KEY,
)