# EXAM_RAG_API/stores/llm/providers/OpenAIProvider.py
# MinaNasser's picture
# 1st
# 1bc3f18
from ..LLMInterface import LLMInterface
from ..LLMEnums import OpenAIEnums
from openai import OpenAI
import logging
class OpenAIProvider(LLMInterface):
    """LLM provider that talks to an OpenAI-compatible API.

    Offers chat-completion text generation and text embedding through the
    ``openai`` client. Runtime failures (missing configuration, malformed
    responses, exceptions raised by the client) are logged and reported to
    the caller as ``None`` instead of being raised.
    """

    def __init__(self, api_key: str, api_url: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1):
        """Store the defaults and create the OpenAI client.

        Args:
            api_key: Key passed to the ``OpenAI`` client.
            api_url: Optional base URL passed to the client as ``base_url``.
            default_input_max_characters: Maximum number of characters kept
                by ``process_text``.
            default_generation_max_output_tokens: Output-token limit used by
                ``generate_text`` when the caller does not supply one.
            default_generation_temperature: Temperature used by
                ``generate_text`` when the caller does not supply one.
        """
        self.api_key = api_key
        self.api_url = api_url

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature

        # Models are selected later through the set_*_model methods.
        self.generation_model_id = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.client = OpenAI(api_key=self.api_key, base_url=self.api_url)
        self.logger = logging.getLogger(__name__)

    def set_generation_model(self, model_id: str):
        """Select the model used by ``generate_text``."""
        self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Select the embedding model and record its vector size."""
        self.embedding_model_id = model_id
        self.embedding_size = embedding_size

    def process_text(self, text: str):
        """Truncate ``text`` to the configured character limit, then strip it."""
        return text[:self.default_input_max_characters].strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a chat completion for ``prompt``.

        Args:
            prompt: User message; it is truncated via ``process_text``.
            chat_history: Optional earlier messages. The sequence is copied,
                so the caller's object is never mutated.
            max_output_tokens: Output-token limit; a missing or zero value
                falls back to the configured default.
            temperature: Sampling temperature; ``None`` falls back to the
                configured default. An explicit ``0`` / ``0.0`` is honoured.

        Returns:
            The content of the first choice, or ``None`` when the provider is
            not configured, the response is malformed/empty, or the client
            raises.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.generation_model_id:
            self.logger.error("OpenAI generation model not set")
            return None

        # A zero token limit is not usable, so any falsy value means "default".
        max_output_tokens = max_output_tokens or self.default_generation_max_output_tokens

        # 0.0 is a legitimate temperature; only None means "use the default".
        if temperature is None:
            temperature = self.default_generation_temperature

        # list() copies the history and also accepts non-list sequences.
        messages = list(chat_history) if chat_history else []
        messages.append(self.construct_prompt(prompt, OpenAIEnums.USER.value))

        try:
            response = self.client.chat.completions.create(
                model=self.generation_model_id,
                messages=messages,
                max_completion_tokens=max_output_tokens,
                temperature=temperature
            )

            if (not response or not response.choices
                    or not response.choices[0].message
                    or not response.choices[0].message.content):
                self.logger.error("Invalid OpenAI response format")
                return None

            return response.choices[0].message.content
        except Exception as e:
            self.logger.exception("Error while generating text with OpenAI: %s", e)
            return None

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embed several texts, sending at most ``batch_size`` per request.

        Args:
            texts: Texts to embed; they are sent unmodified.
            batch_size: Maximum number of texts per API request.

        Returns:
            A list with one embedding per input text, in input order; an
            empty list when ``texts`` is empty; ``None`` when the provider is
            not configured, ``batch_size`` is not positive, a response is
            malformed, or the client raises.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.embedding_model_id:
            self.logger.error("OpenAI embedding model not set")
            return None

        if batch_size < 1:
            self.logger.error("OpenAI embedding batch size must be positive")
            return None

        if not texts:
            return []

        embeddings = []
        try:
            for start in range(0, len(texts), batch_size):
                batch = list(texts[start:start + batch_size])
                response = self.client.embeddings.create(
                    model=self.embedding_model_id,
                    input=batch
                )

                # Every input must come back with exactly one embedding.
                if (not response or not response.data
                        or len(response.data) != len(batch)):
                    self.logger.error("Invalid OpenAI embedding response")
                    return None

                # Order by the reported index so results line up with inputs.
                ordered = sorted(response.data, key=lambda item: item.index)
                if any(not item.embedding for item in ordered):
                    self.logger.error("Invalid OpenAI embedding response")
                    return None

                embeddings.extend(item.embedding for item in ordered)
        except Exception as e:
            self.logger.exception("Error while embedding text with OpenAI: %s", e)
            return None

        return embeddings

    def embed_text(self, text: str, document_type: str = None):
        """Embed a single text.

        Args:
            text: Text to embed; it is sent unmodified.
            document_type: Accepted but not used by this provider.

        Returns:
            The embedding of the first returned item, or ``None`` when the
            provider is not configured, the response is malformed, or the
            client raises.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.embedding_model_id:
            self.logger.error("OpenAI embedding model not set")
            return None

        try:
            response = self.client.embeddings.create(
                model=self.embedding_model_id,
                input=text
            )

            if not response or not response.data or not response.data[0].embedding:
                self.logger.error("Invalid OpenAI embedding response")
                return None

            return response.data[0].embedding
        except Exception as e:
            self.logger.exception("Error while embedding text with OpenAI: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Build a chat message dict with ``role`` and the processed ``prompt``."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }