# EXAM_RAG_API/stores/llm/LLMProviderFactory.py
# Author: MinaNasser — commit 1bc3f18 ("1st")
from .LLMEnums import LLMEnums
from stores.llm.providers.OpenAIProvider import OpenAIProvider
from stores.llm.providers.OllamaProvider import OllamaProvider
from stores.llm.providers.CohereProvider import CohereProvider
from stores.llm.providers.MistralProvider import MistralProvider
from stores.llm.providers.GroqProvider import GroqProvider
from stores.llm.providers.OpenRouterProvider import OpenRouterProvider
from stores.llm.providers.HuggingFaceProvider import HuggingFaceProvider
from stores.llm.providers.DeepSeekProvider import DeepSeekProvider
from stores.llm.providers.GeminiProvider import GeminiProvider
class LLMProviderFactory:
    """Build LLM provider clients from application settings.

    ``create`` compares a provider identifier against the ``.value`` of each
    ``LLMEnums`` member and instantiates the matching provider class, wiring
    in that provider's credentials plus the generation defaults shared by all
    providers.
    """

    def __init__(self, config):
        """Store the settings object used to configure providers.

        Args:
            config: Settings object exposing its fields as attributes
                (e.g. ``config.OPENAI_API_KEY``). A plain ``dict`` does not
                work here because fields are read by attribute access, not
                by indexing.
        """
        self.config = config

    def _shared_generation_kwargs(self) -> dict:
        """Return the generation defaults passed to every provider."""
        # NOTE: the ``DAFAULT`` spelling is the settings field name as read
        # from ``self.config``; do not "fix" it here without renaming the
        # corresponding config fields too.
        cfg = self.config
        return {
            "default_input_max_characters": cfg.INPUT_DAFAULT_MAX_CHARACTERS,
            "default_generation_max_output_tokens": cfg.GENERATION_DAFAULT_MAX_TOKENS,
            "default_generation_temperature": cfg.GENERATION_DAFAULT_TEMPERATURE,
        }

    def create(self, provider: str):
        """Instantiate the provider client registered under ``provider``.

        Only the matched branch reads its settings, so configuration fields
        for providers that are not selected are never accessed.

        Args:
            provider: Provider identifier, compared with ``==`` against the
                ``.value`` of each ``LLMEnums`` member in turn.

        Returns:
            A configured provider instance, or ``None`` when ``provider``
            matches none of the supported ``LLMEnums`` values.
        """
        cfg = self.config

        # Provider-specific kwargs are listed before the shared defaults so
        # config fields are read in the same order as before (left to right).
        if provider == LLMEnums.OPENAI.value:
            return OpenAIProvider(
                api_key=cfg.OPENAI_API_KEY,
                api_url=cfg.OPENAI_API_URL,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.OLLAMA.value:
            # Ollama takes ``url`` (not ``api_url``) as its endpoint argument.
            return OllamaProvider(
                url=cfg.OLLAMA_URL,
                api_key=cfg.OLLAMA_API_KEY,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.COHERE.value:
            return CohereProvider(
                api_key=cfg.COHERE_API_KEY,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.MISTRAL.value:
            return MistralProvider(
                api_key=cfg.MISTRAL_API_KEY,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.GROQ.value:
            return GroqProvider(
                api_key=cfg.GROQ_API_KEY,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.OPENROUTER.value:
            return OpenRouterProvider(
                api_key=cfg.OPENROUTER_API_KEY,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.HUGGINGFACE.value:
            return HuggingFaceProvider(
                api_key=cfg.HF_API_KEY,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.DEEPSEEK.value:
            return DeepSeekProvider(
                api_key=cfg.DEEPSEEK_API_KEY,
                **self._shared_generation_kwargs(),
            )
        if provider == LLMEnums.GEMINI.value:
            return GeminiProvider(
                api_key=cfg.GEMINI_API_KEY,
                **self._shared_generation_kwargs(),
            )

        # No supported provider matched; signal this with None rather than
        # raising, as the factory has always done.
        return None