Spaces:
Paused
Paused
from typing import Any

from stores.llm.providers.CohereProvider import CohereProvider
from stores.llm.providers.DeepSeekProvider import DeepSeekProvider
from stores.llm.providers.GeminiProvider import GeminiProvider
from stores.llm.providers.GroqProvider import GroqProvider
from stores.llm.providers.HuggingFaceProvider import HuggingFaceProvider
from stores.llm.providers.MistralProvider import MistralProvider
from stores.llm.providers.OllamaProvider import OllamaProvider
from stores.llm.providers.OpenAIProvider import OpenAIProvider
from stores.llm.providers.OpenRouterProvider import OpenRouterProvider

from .LLMEnums import LLMEnums
class LLMProviderFactory:
    """Create LLM provider clients from a settings object.

    The factory maps a provider identifier (compared against the ``.value``
    of the ``LLMEnums`` members) to the matching provider class and
    instantiates it with credentials and generation defaults read from
    ``config``.
    """

    def __init__(self, config: Any):
        """Keep a reference to the settings object.

        Args:
            config: Object exposing the provider credentials and the
                generation defaults as *attributes* (for example
                ``config.OPENAI_API_KEY``). A plain ``dict`` does not work
                here, because every lookup in this class uses attribute
                access.
        """
        self.config = config

    def _generation_defaults(self) -> dict:
        """Return the default-limit keyword arguments shared by all providers."""
        settings = self.config
        # NOTE(review): the "DAFAULT" spelling is reproduced exactly as the
        # attribute names are read in this file; presumably it matches the
        # settings class -- verify there before renaming.
        return {
            "default_input_max_characters": settings.INPUT_DAFAULT_MAX_CHARACTERS,
            "default_generation_max_output_tokens": settings.GENERATION_DAFAULT_MAX_TOKENS,
            "default_generation_temperature": settings.GENERATION_DAFAULT_TEMPERATURE,
        }

    def create(self, provider: str):
        """Instantiate the provider identified by ``provider``.

        Only the settings needed by the selected provider are read, so
        credentials of other providers may be absent from ``config``.

        Args:
            provider: Provider identifier; it is compared for equality with
                the ``.value`` of each supported ``LLMEnums`` member.

        Returns:
            A new provider instance, or ``None`` when ``provider`` matches
            none of the supported identifiers.
        """
        settings = self.config

        if provider == LLMEnums.OPENAI.value:
            return OpenAIProvider(
                api_key=settings.OPENAI_API_KEY,
                api_url=settings.OPENAI_API_URL,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.OLLAMA.value:
            return OllamaProvider(
                url=settings.OLLAMA_URL,
                api_key=settings.OLLAMA_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.COHERE.value:
            return CohereProvider(
                api_key=settings.COHERE_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.MISTRAL.value:
            return MistralProvider(
                api_key=settings.MISTRAL_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.GROQ.value:
            return GroqProvider(
                api_key=settings.GROQ_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.OPENROUTER.value:
            return OpenRouterProvider(
                api_key=settings.OPENROUTER_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.HUGGINGFACE.value:
            return HuggingFaceProvider(
                api_key=settings.HF_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.DEEPSEEK.value:
            return DeepSeekProvider(
                api_key=settings.DEEPSEEK_API_KEY,
                **self._generation_defaults(),
            )

        if provider == LLMEnums.GEMINI.value:
            return GeminiProvider(
                api_key=settings.GEMINI_API_KEY,
                **self._generation_defaults(),
            )

        # Unknown identifier: callers receive None rather than an exception.
        return None