import logging
import time
from typing import List, Optional

from core.providers.base import LLMProvider
from core.providers.huggingface import HuggingFaceProvider
from core.providers.ollama import OllamaProvider
from core.providers.openai import OpenAIProvider
from utils.config import config
|
|
# Module-scoped logger, named after this module's import name.
logger = logging.getLogger(name=__name__)
|
|
class ProviderNotAvailableError(Exception):
    """Signal that no LLM provider could be selected to serve a request."""
|
|
class LLMFactory:
    """Singleton factory for LLM providers with ordered fallback.

    Providers (ollama, huggingface, openai, in that order) are created once,
    on first construction, from ``config``. Each created provider gets a
    simple circuit breaker:

    * after ``FAILURE_THRESHOLD`` recorded failures it is "tripped" and
      skipped;
    * once ``RECOVERY_TIMEOUT`` seconds have passed since the last failure it
      is offered again (half-open); a successful validation closes the
      breaker, while another failure re-arms the timeout.
    """

    # Recorded failures after which a provider's circuit breaker trips.
    FAILURE_THRESHOLD = 3
    # Seconds a tripped provider is skipped before it may be retried.
    RECOVERY_TIMEOUT = 60.0

    _instance = None

    def __new__(cls):
        # Classic singleton: every LLMFactory() call returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every LLMFactory() call; only set up state once.
        if self._initialized:
            return

        self._initialized = True
        # Per-instance (not class-level) so provider state is never shared
        # through the class object or leaked across a singleton reset.
        self._providers = {}
        self._provider_chain = []
        self._circuit_breakers = {}
        self._initialize_providers()

    def _initialize_providers(self) -> None:
        """Create every provider enabled in ``config``, in fallback order.

        A provider is enabled when its host/credential setting is truthy.
        A provider whose constructor raises is logged and skipped. Each
        provider that is created is appended to the fallback chain and gets
        a fresh circuit breaker.
        """
        provider_configs = [
            {
                'name': 'ollama',
                'class': OllamaProvider,
                'enabled': bool(config.ollama_host),
                'model': config.local_model_name,
            },
            {
                'name': 'huggingface',
                'class': HuggingFaceProvider,
                'enabled': bool(config.hf_token),
                'model': "meta-llama/Llama-2-7b-chat-hf",
            },
            {
                'name': 'openai',
                'class': OpenAIProvider,
                'enabled': bool(config.openai_api_key),
                'model': "gpt-3.5-turbo",
            },
        ]

        for provider_config in provider_configs:
            if not provider_config['enabled']:
                continue
            name = provider_config['name']
            try:
                provider = provider_config['class'](
                    model_name=provider_config['model']
                )
            except Exception as e:
                # Best-effort: one misconfigured provider must not prevent
                # the others from being available.
                logger.warning("Failed to initialize %s provider: %s", name, e)
                continue

            self._providers[name] = provider
            self._provider_chain.append(name)
            self._circuit_breakers[name] = {
                'failures': 0,
                'last_failure': None,  # time.monotonic() of the latest failure
                'tripped': False,
            }
            logger.info("Initialized %s provider", name)

    def get_provider(self, preferred_provider: Optional[str] = None) -> "LLMProvider":
        """
        Get an LLM provider based on preference and availability.

        The preferred provider (if given and initialized) is tried first;
        otherwise, or if it is unavailable or fails validation, the remaining
        providers are tried in fallback-chain order. Each provider is
        validated at most once per call.

        Args:
            preferred_provider: Preferred provider name (ollama, huggingface, openai)

        Returns:
            LLMProvider instance

        Raises:
            ProviderNotAvailableError: When no providers are available
        """
        tried = set()

        if preferred_provider:
            if preferred_provider in self._providers:
                tried.add(preferred_provider)
                if self._try_provider(preferred_provider):
                    logger.info("Using preferred provider: %s", preferred_provider)
                    return self._providers[preferred_provider]
            else:
                logger.warning(
                    "Preferred provider %r is not initialized; falling back",
                    preferred_provider,
                )

        for provider_name in self._provider_chain:
            if provider_name in tried:
                continue  # already validated above; don't call it twice
            if self._try_provider(provider_name):
                logger.info("Using fallback provider: %s", provider_name)
                return self._providers[provider_name]

        raise ProviderNotAvailableError("No LLM providers are available")

    def get_all_providers(self) -> List["LLMProvider"]:
        """Return every successfully initialized provider, in creation order."""
        return list(self._providers.values())

    def _try_provider(self, provider_name: str) -> bool:
        """Return True if *provider_name* passes its breaker and validates.

        An exception from ``validate_model()`` is logged and recorded as a
        circuit-breaker failure. A truthy validation result closes the breaker.
        """
        if not self._is_provider_available(provider_name):
            return False

        provider = self._providers[provider_name]
        try:
            valid = provider.validate_model()
        except Exception as e:
            logger.warning("Provider %s model validation failed: %s", provider_name, e)
            self._record_provider_failure(provider_name)
            return False

        if valid:
            self._record_provider_success(provider_name)
        return bool(valid)

    def _is_provider_available(self, provider_name: str) -> bool:
        """Return True if the provider exists and its breaker lets a call through.

        A tripped breaker becomes available again (half-open) once
        ``RECOVERY_TIMEOUT`` seconds have elapsed since its last failure.
        """
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return False
        if not breaker['tripped']:
            return True

        last_failure = breaker['last_failure']
        if last_failure is None:
            return False
        # Half-open: allow a retry after the cooldown. Success resets the
        # breaker; another failure refreshes last_failure and re-arms it.
        return time.monotonic() - last_failure >= self.RECOVERY_TIMEOUT

    def _record_provider_failure(self, provider_name: str) -> None:
        """Record a failure; trip the breaker once the threshold is reached."""
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return

        breaker['failures'] += 1
        breaker['last_failure'] = time.monotonic()

        if breaker['failures'] >= self.FAILURE_THRESHOLD:
            breaker['tripped'] = True
            logger.warning("Circuit breaker tripped for provider: %s", provider_name)

    def _record_provider_success(self, provider_name: str) -> None:
        """Close the provider's breaker and clear its failure count."""
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return

        if breaker['tripped']:
            logger.info("Circuit breaker reset for provider: %s", provider_name)
        breaker['failures'] = 0
        breaker['tripped'] = False
|
|
| |
# Shared module-level singleton. Importing this module constructs the
# factory, which initializes the providers enabled in the configuration.
llm_factory: LLMFactory = LLMFactory()
|
|