Spaces:
Paused
Paused
| from ..LLMInterface import LLMInterface | |
| from ..LLMEnums import OpenAIEnums | |
| from openai import OpenAI | |
| import logging | |
class OpenAIProvider(LLMInterface):
    """LLMInterface implementation backed by the OpenAI Python client.

    A generation model and/or an embedding model must be selected with
    ``set_generation_model`` / ``set_embedding_model`` before the matching
    generate/embed methods are used; until then those methods log an error
    and return ``None``.
    """

    def __init__(self, api_key: str, api_url: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1):
        """Store the configuration and create the OpenAI client.

        Args:
            api_key: Key handed to the OpenAI client.
            api_url: Optional base URL handed to the client as ``base_url``.
            default_input_max_characters: Maximum number of characters kept
                by ``process_text``.
            default_generation_max_output_tokens: Token limit used by
                ``generate_text`` when the caller gives none.
            default_generation_temperature: Temperature used by
                ``generate_text`` when the caller gives none.
        """
        self.api_key = api_key
        self.api_url = api_url

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature

        # Models are chosen later through the set_* methods.
        self.generation_model_id = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.client = OpenAI(api_key=self.api_key, base_url=self.api_url)
        self.logger = logging.getLogger(__name__)

    def set_generation_model(self, model_id: str):
        """Select the model id used by ``generate_text``."""
        self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Select the embedding model id and record its embedding size.

        The size is only stored on the instance; this class does not send it
        to the API.
        """
        self.embedding_model_id = model_id
        self.embedding_size = embedding_size

    def process_text(self, text: str):
        """Truncate ``text`` to the configured character limit, then strip it."""
        return text[:self.default_input_max_characters].strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a chat completion for ``prompt``.

        Args:
            prompt: User text; it is passed through ``process_text`` (and so
                truncated) before being sent.
            chat_history: Optional list of earlier messages, presumably dicts
                shaped like the output of ``construct_prompt``. The list is
                copied, so the caller's list is not modified.
            max_output_tokens: Completion token limit; a falsy value
                (``None`` or ``0``) selects the instance default.
            temperature: Sampling temperature; ``None`` selects the instance
                default. An explicit ``0.0`` is honoured.

        Returns:
            The content of the first choice, or ``None`` when the client or
            model is missing, the response is empty, or the request raises.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.generation_model_id:
            self.logger.error("OpenAI generation model not set")
            return None

        # A zero-token limit is not meaningful, so any falsy value falls back
        # to the default here.
        max_output_tokens = max_output_tokens or self.default_generation_max_output_tokens

        # 0.0 is a legitimate temperature, so only None falls back to the
        # default (a truthiness test would discard an explicit 0.0).
        if temperature is None:
            temperature = self.default_generation_temperature

        messages = chat_history[:] if chat_history else []
        messages.append(self.construct_prompt(prompt, OpenAIEnums.USER.value))

        try:
            response = self.client.chat.completions.create(
                model=self.generation_model_id,
                messages=messages,
                max_completion_tokens=max_output_tokens,
                temperature=temperature
            )

            if (not response or not response.choices
                    or not response.choices[0].message
                    or not response.choices[0].message.content):
                self.logger.error("Invalid OpenAI response format")
                return None

            return response.choices[0].message.content
        except Exception as e:
            self.logger.exception("Error while generating text with OpenAI: %s", e)
            return None

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embed several texts, sending at most ``batch_size`` per request.

        Args:
            texts: Texts to embed. As in ``embed_text``, they are sent as-is
                (no truncation).
            batch_size: Maximum number of texts per API request; must be at
                least 1.

        Returns:
            A list with one embedding per input text, in input order; an
            empty list when ``texts`` is empty; or ``None`` when the client
            or model is missing, ``batch_size`` is invalid, a response is
            malformed, or a request raises.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.embedding_model_id:
            self.logger.error("OpenAI embedding model not set")
            return None

        if not texts:
            return []

        if batch_size < 1:
            self.logger.error("Invalid batch size for OpenAI embedding: %s", batch_size)
            return None

        embeddings = []
        try:
            for start in range(0, len(texts), batch_size):
                batch = texts[start:start + batch_size]
                response = self.client.embeddings.create(
                    model=self.embedding_model_id,
                    input=batch
                )

                if (not response or not response.data
                        or len(response.data) != len(batch)):
                    self.logger.error("Invalid OpenAI embedding response")
                    return None

                # Each returned item carries the index of the input it
                # belongs to; sort on it so outputs stay aligned with inputs.
                ordered = sorted(response.data, key=lambda item: item.index)
                if any(not item.embedding for item in ordered):
                    self.logger.error("Invalid OpenAI embedding response")
                    return None

                embeddings.extend(item.embedding for item in ordered)
        except Exception as e:
            self.logger.exception("Error while embedding text batch with OpenAI: %s", e)
            return None

        return embeddings

    def embed_text(self, text: str, document_type: str = None):
        """Embed a single text with the configured embedding model.

        Args:
            text: Text to embed; sent as-is (no truncation).
            document_type: Not used by this provider.

        Returns:
            The embedding of the first response item, or ``None`` when the
            client or model is missing, the response is empty, or the
            request raises.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.embedding_model_id:
            self.logger.error("OpenAI embedding model not set")
            return None

        try:
            response = self.client.embeddings.create(
                model=self.embedding_model_id,
                input=text
            )

            if not response or not response.data or not response.data[0].embedding:
                self.logger.error("Invalid OpenAI embedding response")
                return None

            return response.data[0].embedding
        except Exception as e:
            self.logger.exception("Error while embedding text with OpenAI: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Build a chat message dict for ``role`` with the processed prompt."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }