# comic-grading/openai_wrapper.py
"""Model wrapper to interact with OpenAI models."""
import abc
import ast
import json
from typing import Mapping, Optional

import openai

import parameters
class OpenAIModel(abc.ABC):
    """Thin wrapper around the OpenAI chat-completions client.

    Holds one `openai.OpenAI` client plus a model name, and exposes helpers
    to turn a {role: content} mapping into a chat conversation, run a single
    completion, and retry until the model returns parseable JSON.
    """

    # Retained for backward compatibility; the real key is passed to __init__.
    API_KEY = ""
    # TODO(Maani): Add support for more generation options like:
    # 1. temperature
    # 2. top-p
    # 3. stop sequences
    # 4. num_outputs
    # 5. response_format
    # 6. seed
    def __init__(self, model_name: str, API_KEY: str):
        """Create the OpenAI client for `model_name` authenticated with `API_KEY`.

        Raises:
            Exception: if client construction fails; the original error is
                chained as the cause.
        """
        try:
            self.client = openai.OpenAI(
                # This is the default and can be omitted
                api_key=API_KEY
            )
            self.model_name = model_name
        except Exception as exc:
            raise Exception(
                "Failed to initialize OpenAI model client. See traceback for more details.",
            ) from exc

    def prepare_input(self, prompt_dict: Mapping[str, str]) -> list:
        """Convert a {role: content} mapping into a chat message list.

        Args:
            prompt_dict: mapping of chat role (e.g. "system", "user") to text.

        Returns:
            A list of {"role": ..., "content": ...} dicts, one per entry, in
            the mapping's iteration order.

        Raises:
            Exception: if `prompt_dict` is not iterable as a mapping.
        """
        try:
            return [
                {"role": role, "content": content}
                for role, content in prompt_dict.items()
            ]
        except Exception as exc:
            raise Exception(
                f"Incomplete Prompt Dictionary Passed. Expected to have atleast a role and it's content.\nPassed dict: {prompt_dict}",
            ) from exc

    def generate_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,  # was annotated `int`; None is the default
        temperature: float = 0.6,  # was annotated `int` despite the float default
        response_format: Optional[dict] = None,
    ) -> str:
        """Run one chat completion and return the first choice's text.

        Args:
            prompt_dict: {role: content} mapping forwarded to `prepare_input`.
            max_output_tokens: cap on generated tokens, or None for no cap.
            temperature: sampling temperature.
            response_format: optional OpenAI response-format dict, e.g.
                {"type": "json_object"}.

        Raises:
            Exception: wrapping any API/client failure, with the model name
                and the prepared conversation in the message.
        """
        conversation = self.prepare_input(prompt_dict)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=conversation,
                max_tokens=max_output_tokens,
                temperature=temperature,
                response_format=response_format,
            )
            return response.choices[0].message.content
        except Exception as exc:
            raise Exception(
                f"Exception in generating model response.\nModel name: {self.model_name}\nInput prompt: {str(conversation)}",
            ) from exc

    def generate_valid_json_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,
        temperature: float = 0.7,
    ) -> dict:
        """Generate a response with retries, returning parsed JSON.

        Requests JSON-mode output and retries up to `parameters.MAX_TRIES`
        times until the response parses.

        Raises:
            Exception: if no attempt yields valid JSON.
        """
        for _ in range(int(parameters.MAX_TRIES)):
            try:
                model_response = self.generate_response(
                    prompt_dict, max_output_tokens, temperature, {"type": "json_object"}
                )
                # json.loads, not ast.literal_eval: the model emits JSON, whose
                # true/false/null literals ast.literal_eval rejects, which made
                # every such response burn a retry.
                return json.loads(model_response)
            except Exception:
                continue
        raise Exception(
            f"Maximum retries met before valid JSON structure was found.\nModel name: {self.model_name}\nInput prompt: {str(prompt_dict)}"
        )
# Module-level ready-to-use client for OpenAI's "gpt-4o-mini" model,
# authenticated with the key from `parameters`. NOTE(review): constructing
# this at import time means importing the module requires a usable key.
GPT_4O_MINI = OpenAIModel("gpt-4o-mini", parameters.OPEN_AI_API_KEY)