Spaces:
Runtime error
Runtime error
"""Model wrapper to interact with xAI (Grok) models through the OpenAI client.

The ``openai`` SDK is pointed at xAI's endpoint (``https://api.x.ai/v1``)
and used for chat-completion requests.
"""
import abc
import ast
import json
from typing import Any, Dict, List, Mapping, Optional

import openai

from src import parameters
# Module logger, taken from the shared ``parameters`` module.
logger = parameters.LOGGER
class xAIModel(abc.ABC):
    """Wrapper around xAI chat models, accessed through the ``openai`` client.

    The client is configured with ``base_url="https://api.x.ai/v1"`` and
    requests go through ``client.chat.completions.create``.
    """

    # Kept for backward compatibility only; the key actually used is the one
    # passed to ``__init__``.
    API_KEY = ""

    def __init__(self, model_name: str, API_KEY: str):
        """Create a client for ``model_name`` authenticated with ``API_KEY``.

        Args:
            model_name: xAI model identifier sent with every request.
            API_KEY: xAI API key.

        Raises:
            Exception: if ``API_KEY`` is ``None`` or the client cannot be
                constructed (the original error is chained as the cause).
        """
        if API_KEY is None:
            # With api_key=None the openai client falls back to the
            # OPENAI_API_KEY environment variable, which would then be sent
            # to api.x.ai. Refuse instead of leaking an unrelated credential.
            raise Exception(
                "xAI API_KEY is None; refusing to fall back to OPENAI_API_KEY."
            )
        self.model_name = model_name
        try:
            self.client = openai.OpenAI(
                api_key=API_KEY,
                base_url="https://api.x.ai/v1",
            )
        except Exception as exc:
            raise Exception(
                "Failed to initialize Grok xAI model client. See traceback for more details.",
            ) from exc

    def prepare_input(self, prompt_dict: Mapping[str, str]) -> List[Dict[str, str]]:
        """Convert a ``{role: content}`` mapping into a chat message list.

        Messages follow the mapping's iteration order. Because roles are the
        keys, at most one message per role can be expressed.

        Args:
            prompt_dict: mapping of role (e.g. ``"system"``, ``"user"``) to
                message content.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.

        Raises:
            Exception: if ``prompt_dict`` cannot be iterated via ``.items()``.
        """
        try:
            return [
                {"role": role, "content": content}
                for role, content in prompt_dict.items()
            ]
        except Exception as exc:
            raise Exception(
                f"Incomplete Prompt Dictionary Passed. Expected to have atleast a role and it's content.\nPassed dict: {prompt_dict}",
            ) from exc

    def generate_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,
        temperature: float = 0.6,
        response_format: Optional[dict] = None,
    ) -> Optional[str]:
        """Send one chat-completion request and return the first choice's text.

        Args:
            prompt_dict: ``{role: content}`` mapping (see ``prepare_input``).
            max_output_tokens: token cap for the reply; falsy means "not set".
            temperature: sampling temperature.
            response_format: optional ``response_format`` payload,
                e.g. ``{"type": "json_object"}``.

        Returns:
            ``choices[0].message.content`` of the response (the openai SDK
            types this field as optional).

        Raises:
            Exception: wrapping any error raised by the request.
        """
        conversation = self.prepare_input(prompt_dict)
        # Only forward optional parameters that were actually set, so the
        # request body does not carry explicit nulls for them.
        optional_kwargs: Dict[str, Any] = {}
        if max_output_tokens:
            optional_kwargs["max_tokens"] = max_output_tokens
        if response_format is not None:
            optional_kwargs["response_format"] = response_format
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=conversation,
                temperature=temperature,
                **optional_kwargs,
            )
            return response.choices[0].message.content
        except Exception as exc:
            raise Exception(
                f"Exception in generating model response.\nModel name: {self.model_name}\nInput prompt: {str(conversation)}",
            ) from exc

    @staticmethod
    def _parse_json(text: str) -> Any:
        """Parse a model reply as JSON, falling back to a Python literal.

        ``json.loads`` accepts ``true``/``false``/``null``, which
        ``ast.literal_eval`` rejects. The fallback keeps accepting replies
        that are Python literals (e.g. single-quoted dicts), as before.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            # literal_eval evaluates literals only (no code execution), but
            # very large or deeply nested input can still exhaust resources.
            return ast.literal_eval(text)

    def generate_valid_json_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: Optional[int] = None,
        temperature: float = 0.6,
    ) -> Any:
        """Generate a response with retries, returning the parsed JSON value.

        Requests ``{"type": "json_object"}`` output and retries up to
        ``parameters.MAX_TRIES`` times when the request or the parse fails.

        Returns:
            The parsed reply (normally a dict).

        Raises:
            Exception: when every attempt fails; the last failure is chained
                as the cause.
        """
        max_tries = parameters.MAX_TRIES
        last_exc: Optional[Exception] = None
        for attempt in range(1, max_tries + 1):
            try:
                model_response = self.generate_response(
                    prompt_dict, max_output_tokens, temperature, {"type": "json_object"}
                )
                return self._parse_json(model_response)
            except Exception as exc:
                # Keep retrying, but leave a trace instead of silently
                # swallowing the failure.
                last_exc = exc
                logger.warning(
                    f"Attempt {attempt}/{max_tries} to get valid JSON from "
                    f"{self.model_name} failed: {exc!r}"
                )
        raise Exception(
            f"Maximum retries met before valid JSON structure was found.\nModel name: {self.model_name}\nInput prompt: {str(prompt_dict)}"
        ) from last_exc
# Module-level Grok-2 client, constructed at import time with the key
# provided by the shared ``parameters`` module.
GROK_2 = xAIModel(
    model_name="grok-2-1212",
    API_KEY=parameters.XAI_API_KEY,
)