from typing import Any

from .LLMEnums import LLMEnums
from stores.llm.providers.OpenAIProvider import OpenAIProvider
from stores.llm.providers.OllamaProvider import OllamaProvider
from stores.llm.providers.CohereProvider import CohereProvider
from stores.llm.providers.MistralProvider import MistralProvider
from stores.llm.providers.GroqProvider import GroqProvider
from stores.llm.providers.OpenRouterProvider import OpenRouterProvider
from stores.llm.providers.HuggingFaceProvider import HuggingFaceProvider
from stores.llm.providers.DeepSeekProvider import DeepSeekProvider
from stores.llm.providers.GeminiProvider import GeminiProvider


class LLMProviderFactory:
    """Create LLM provider clients from a settings object.

    The settings object is read through attribute access (for example
    ``config.OPENAI_API_KEY``), so it must expose its values as attributes;
    a plain ``dict`` would raise ``AttributeError``.
    """

    def __init__(self, config: Any):
        """Store the settings object used to configure created providers.

        Args:
            config: Object exposing provider credentials and the shared
                default limits as attributes.
        """
        self.config = config

    def _generation_defaults(self) -> dict:
        """Return the default-limit keyword arguments shared by every provider."""
        # NOTE(review): the "DAFAULT" spelling is the attribute name this
        # class reads; renaming it here alone would break the lookup, so any
        # correction has to be coordinated with the settings definition.
        return {
            "default_input_max_characters": self.config.INPUT_DAFAULT_MAX_CHARACTERS,
            "default_generation_max_output_tokens": self.config.GENERATION_DAFAULT_MAX_TOKENS,
            "default_generation_temperature": self.config.GENERATION_DAFAULT_TEMPERATURE,
        }

    def create(self, provider: str):
        """Instantiate the provider client matching ``provider``.

        Settings are read only for the branch that matches, so a missing
        attribute for one provider does not affect the others.

        Args:
            provider: Provider identifier, compared with ``==`` against the
                ``.value`` of each ``LLMEnums`` member.

        Returns:
            A newly constructed provider instance, or ``None`` when
            ``provider`` matches none of the supported identifiers.
        """
        # Credentials are listed before the unpacked defaults so that the
        # settings attributes are read in the same order for every branch.
        if provider == LLMEnums.OPENAI.value:
            return OpenAIProvider(
                api_key=self.config.OPENAI_API_KEY,
                api_url=self.config.OPENAI_API_URL,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.OLLAMA.value:
            return OllamaProvider(
                url=self.config.OLLAMA_URL,
                api_key=self.config.OLLAMA_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.COHERE.value:
            return CohereProvider(
                api_key=self.config.COHERE_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.MISTRAL.value:
            return MistralProvider(
                api_key=self.config.MISTRAL_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.GROQ.value:
            return GroqProvider(
                api_key=self.config.GROQ_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.OPENROUTER.value:
            return OpenRouterProvider(
                api_key=self.config.OPENROUTER_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.HUGGINGFACE.value:
            return HuggingFaceProvider(
                api_key=self.config.HF_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.DEEPSEEK.value:
            return DeepSeekProvider(
                api_key=self.config.DEEPSEEK_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.GEMINI.value:
            return GeminiProvider(
                api_key=self.config.GEMINI_API_KEY,
                **self._generation_defaults(),
            )

        # Unknown provider: callers receive None rather than an exception.
        return None