| """ |
| DungeonMaster AI - LLM Provider Chain |
| |
| Manages LLM providers with automatic fallback from Gemini to OpenAI. |
| Includes circuit breaker pattern for reliability. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import logging |
| import time |
| from datetime import datetime, timedelta |
| from enum import Enum |
|
|
| from llama_index.core.llms import ChatMessage, LLM |
| from llama_index.llms.gemini import Gemini |
| from llama_index.llms.openai import OpenAI |
|
|
| from src.config.settings import AppSettings, get_settings |
|
|
| from .exceptions import ( |
| LLMAllProvidersFailedError, |
| LLMAuthenticationError, |
| LLMCircuitBreakerOpenError, |
| LLMQuotaExhaustedError, |
| LLMRateLimitError, |
| LLMTimeoutError, |
| ) |
| from .models import LLMProviderHealth, LLMResponse |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class CircuitState(str, Enum):
    """Circuit breaker states.

    Subclasses ``str`` so state values compare and serialize as plain
    strings (e.g. in health payloads and logs).
    """

    CLOSED = "closed"  # Normal operation: requests flow to the provider.
    OPEN = "open"  # Provider blocked after repeated failures.
    HALF_OPEN = "half_open"  # Cooldown elapsed: one probe request allowed.
|
|
|
|
class ProviderCircuitBreaker:
    """
    Circuit breaker for an individual LLM provider.

    Prevents cascading failures by temporarily blocking requests
    to a failing provider.

    State machine:
        CLOSED    -> OPEN       after ``failure_threshold`` consecutive failures
        OPEN      -> HALF_OPEN  once ``reset_timeout`` seconds have elapsed
        HALF_OPEN -> CLOSED     on the next success
        HALF_OPEN -> OPEN       on the next failure
    """

    def __init__(
        self,
        provider_name: str,
        failure_threshold: int = 3,
        reset_timeout: float = 60.0,
    ) -> None:
        """
        Args:
            provider_name: Provider identifier used in logs and health reports.
            failure_threshold: Consecutive failures before the circuit opens.
            reset_timeout: Seconds to wait before allowing a probe request.
        """
        self.provider_name = provider_name
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        # Wall-clock timestamps are kept for health reporting only.
        self._last_failure_time: datetime | None = None
        self._last_success_time: datetime | None = None
        # The reset countdown uses a monotonic clock: wall-clock deltas
        # (the original approach) are skewed by NTP/DST/manual clock
        # changes and could keep the circuit open forever or reset it
        # prematurely.
        self._last_failure_monotonic: float | None = None

    @property
    def state(self) -> CircuitState:
        """Get current circuit state, checking for timeout transitions."""
        if self._state == CircuitState.OPEN and self._should_attempt_reset():
            self._state = CircuitState.HALF_OPEN
            logger.info(
                f"Circuit breaker for {self.provider_name} "
                "transitioning to HALF_OPEN"
            )
        return self._state

    @property
    def is_available(self) -> bool:
        """Check if provider is available for requests."""
        # Reading ``state`` (not ``_state``) triggers OPEN -> HALF_OPEN.
        return self.state != CircuitState.OPEN

    def _should_attempt_reset(self) -> bool:
        """Check if enough time has passed to attempt reset."""
        if self._last_failure_monotonic is None:
            return True
        elapsed = time.monotonic() - self._last_failure_monotonic
        return elapsed >= self.reset_timeout

    def record_success(self) -> None:
        """Record a successful request, closing the circuit if half-open."""
        self._failure_count = 0
        self._last_success_time = datetime.now()

        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            logger.info(
                f"Circuit breaker for {self.provider_name} "
                "CLOSED after successful test"
            )

    def record_failure(self, error: Exception) -> None:
        """Record a failed request, opening the circuit when warranted."""
        self._failure_count += 1
        self._last_failure_time = datetime.now()
        self._last_failure_monotonic = time.monotonic()

        logger.warning(
            f"Provider {self.provider_name} failure "
            f"({self._failure_count}/{self.failure_threshold}): {error}"
        )

        if self._state == CircuitState.HALF_OPEN:
            # The probe request failed: trip again immediately.
            self._state = CircuitState.OPEN
            logger.warning(
                f"Circuit breaker for {self.provider_name} "
                "OPEN after failed test"
            )
        elif self._failure_count >= self.failure_threshold:
            self._state = CircuitState.OPEN
            logger.warning(
                f"Circuit breaker for {self.provider_name} OPENED "
                f"after {self._failure_count} failures"
            )

    def get_health(self) -> LLMProviderHealth:
        """Get health status for this provider.

        ``is_primary`` is always False here; callers that know the chain
        ordering override it after the fact.
        """
        return LLMProviderHealth(
            provider_name=self.provider_name,
            is_available=self.is_available,
            is_primary=False,
            consecutive_failures=self._failure_count,
            last_success=self._last_success_time,
            last_error=None,  # TODO(review): consider surfacing the last error string
            circuit_open=self.state == CircuitState.OPEN,
        )
|
|
|
|
class LLMFallbackChain:
    """
    Manages LLM providers with automatic fallback.

    Order: Gemini (primary) -> OpenAI (fallback) -> Error

    Features:
    - Circuit breaker per provider (3 failures, 60s reset)
    - Configurable timeouts (Gemini: 30s, OpenAI: 45s)
    - Automatic fallback on rate limits, timeouts, errors
    - Provider health tracking
    """

    # Class-level DEFAULT timeouts (seconds). Each instance copies this
    # mapping in __init__; never mutate it at class level.
    PROVIDER_TIMEOUTS: dict[str, float] = {
        "gemini": 30.0,
        "openai": 45.0,
    }

    def __init__(
        self,
        settings: AppSettings | None = None,
        gemini_timeout: float | None = None,
        openai_timeout: float | None = None,
    ) -> None:
        """
        Initialize the LLM fallback chain.

        Args:
            settings: Application settings. Defaults to get_settings().
            gemini_timeout: Override Gemini timeout (seconds).
            openai_timeout: Override OpenAI timeout (seconds).
        """
        self._settings = settings or get_settings()

        # BUG FIX: the original wrote overrides into the *class* dict, so
        # one instance's timeout override leaked into every other chain
        # instance. Shadow it with a per-instance copy instead; reads via
        # ``self.PROVIDER_TIMEOUTS`` keep working unchanged.
        # ``is not None`` (not truthiness) so an explicit 0.0 is honored.
        self.PROVIDER_TIMEOUTS = dict(type(self).PROVIDER_TIMEOUTS)
        if gemini_timeout is not None:
            self.PROVIDER_TIMEOUTS["gemini"] = gemini_timeout
        if openai_timeout is not None:
            self.PROVIDER_TIMEOUTS["openai"] = openai_timeout

        # LLM clients are created lazily on first use.
        self._gemini: Gemini | None = None
        self._openai: OpenAI | None = None

        # One independent circuit breaker per provider.
        self._gemini_breaker = ProviderCircuitBreaker("gemini")
        self._openai_breaker = ProviderCircuitBreaker("openai")

        # Name of the provider that served the most recent request.
        self._current_provider: str | None = None

        logger.debug("LLMFallbackChain initialized")

    def _get_gemini(self) -> Gemini | None:
        """Get or create Gemini LLM instance.

        Returns None (after logging) when the API key is missing or the
        client fails to construct; callers treat None as "unavailable".
        """
        if self._gemini is not None:
            return self._gemini

        if not self._settings.llm.has_gemini:
            logger.warning("Gemini API key not configured")
            return None

        try:
            self._gemini = Gemini(
                api_key=self._settings.llm.gemini_api_key,
                model=self._settings.llm.gemini_model,
                temperature=self._settings.llm.temperature,
                max_tokens=self._settings.llm.max_tokens,
            )
            logger.info(f"Gemini LLM initialized: {self._settings.llm.gemini_model}")
            return self._gemini
        except Exception as e:
            logger.error(f"Failed to initialize Gemini: {e}")
            return None

    def _get_openai(self) -> OpenAI | None:
        """Get or create OpenAI LLM instance.

        Returns None (after logging) when the API key is missing or the
        client fails to construct; callers treat None as "unavailable".
        """
        if self._openai is not None:
            return self._openai

        if not self._settings.llm.has_openai:
            logger.warning("OpenAI API key not configured")
            return None

        try:
            self._openai = OpenAI(
                api_key=self._settings.llm.openai_api_key,
                model=self._settings.llm.openai_model,
                temperature=self._settings.llm.temperature,
                max_tokens=self._settings.llm.max_tokens,
            )
            logger.info(f"OpenAI LLM initialized: {self._settings.llm.openai_model}")
            return self._openai
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI: {e}")
            return None

    def get_primary_llm(self) -> LLM | None:
        """
        Get the primary LLM for use with LlamaIndex agents.

        Returns Gemini if available (configured and circuit closed),
        otherwise OpenAI, otherwise None.
        """
        gemini = self._get_gemini()
        if gemini and self._gemini_breaker.is_available:
            return gemini

        openai = self._get_openai()
        if openai and self._openai_breaker.is_available:
            return openai

        return None

    def get_fallback_llm(self) -> LLM | None:
        """
        Get the fallback LLM for use when primary fails.

        Returns OpenAI only while Gemini is the acting primary (i.e. both
        are available); otherwise None.
        """
        if self._get_gemini() and self._gemini_breaker.is_available:
            openai = self._get_openai()
            if openai and self._openai_breaker.is_available:
                return openai

        return None

    async def generate(
        self,
        messages: list[ChatMessage],
        timeout: float | None = None,
    ) -> LLMResponse:
        """
        Generate a response using the LLM chain with fallback.

        Args:
            messages: Chat messages to send.
            timeout: Optional timeout override (seconds) applied to
                whichever provider is attempted.

        Returns:
            LLMResponse with generated text and metadata.

        Raises:
            LLMAllProvidersFailedError: If all providers fail; carries a
                per-provider reason (error text, "Circuit breaker open",
                or "Provider not configured").
        """
        errors: dict[str, str] = {}
        start_time = time.time()

        # --- Primary: Gemini ---
        gemini = self._get_gemini()
        if gemini is None:
            # Previously skipped providers left no trace in ``errors``;
            # record why so the final exception is diagnosable.
            errors["gemini"] = "Provider not configured"
        elif not self._gemini_breaker.is_available:
            errors["gemini"] = "Circuit breaker open"
        else:
            try:
                response = await self._call_provider(
                    provider_name="gemini",
                    llm=gemini,
                    messages=messages,
                    timeout=timeout if timeout is not None
                    else self.PROVIDER_TIMEOUTS["gemini"],
                    breaker=self._gemini_breaker,
                )
                response.latency_ms = (time.time() - start_time) * 1000
                return response
            except Exception as e:
                errors["gemini"] = str(e)
                logger.warning(f"Gemini failed, trying fallback: {e}")

        # --- Fallback: OpenAI ---
        openai = self._get_openai()
        if openai is None:
            errors["openai"] = "Provider not configured"
        elif not self._openai_breaker.is_available:
            errors["openai"] = "Circuit breaker open"
        else:
            try:
                response = await self._call_provider(
                    provider_name="openai",
                    llm=openai,
                    messages=messages,
                    timeout=timeout if timeout is not None
                    else self.PROVIDER_TIMEOUTS["openai"],
                    breaker=self._openai_breaker,
                )
                response.from_fallback = True
                response.latency_ms = (time.time() - start_time) * 1000
                return response
            except Exception as e:
                errors["openai"] = str(e)
                logger.error(f"OpenAI fallback also failed: {e}")

        # Every provider was skipped or failed.
        raise LLMAllProvidersFailedError(errors)

    async def _call_provider(
        self,
        provider_name: str,
        llm: LLM,
        messages: list[ChatMessage],
        timeout: float,
        breaker: ProviderCircuitBreaker,
    ) -> LLMResponse:
        """
        Call a specific LLM provider with timeout and circuit breaker.

        Args:
            provider_name: Name of the provider.
            llm: LLM instance to use.
            messages: Messages to send.
            timeout: Request timeout in seconds.
            breaker: Circuit breaker for this provider.

        Returns:
            LLMResponse with generated text.

        Raises:
            LLMCircuitBreakerOpenError: If the breaker is open.
            LLMTimeoutError: If the call exceeds ``timeout``.
            LLMRateLimitError / LLMAuthenticationError /
            LLMQuotaExhaustedError: Classified provider failures.
        """
        if not breaker.is_available:
            raise LLMCircuitBreakerOpenError(provider_name)

        try:
            response = await asyncio.wait_for(
                llm.achat(messages),
                timeout=timeout,
            )

            text = response.message.content if response.message else ""

            breaker.record_success()
            self._current_provider = provider_name

            return LLMResponse(
                text=str(text),
                provider_used=provider_name,
                model_used=getattr(llm, "model", "unknown"),
            )

        except asyncio.TimeoutError as e:
            breaker.record_failure(e)
            raise LLMTimeoutError(provider_name, timeout) from e

        except Exception as e:
            # Heuristic classification by error text: provider SDKs raise
            # heterogeneous exception types, so we pattern-match messages.
            # NOTE(review): "exceeded" also appears in e.g. context-length
            # errors, which this maps to quota — confirm acceptable.
            error_str = str(e).lower()

            if "rate" in error_str and "limit" in error_str:
                breaker.record_failure(e)
                raise LLMRateLimitError(provider_name) from e

            if "auth" in error_str or "api key" in error_str or "401" in error_str:
                breaker.record_failure(e)
                raise LLMAuthenticationError(provider_name) from e

            if "quota" in error_str or "exceeded" in error_str:
                breaker.record_failure(e)
                raise LLMQuotaExhaustedError(provider_name) from e

            # Unclassified: count the failure and re-raise as-is.
            breaker.record_failure(e)
            raise

    def get_health(self) -> dict[str, LLMProviderHealth]:
        """
        Get health status of all providers.

        Availability combines configuration (API key present) with the
        provider's circuit breaker state.

        Returns:
            Dictionary mapping provider name to health status.
        """
        health = {}

        gemini_health = self._gemini_breaker.get_health()
        gemini_health.is_primary = True
        gemini_health.is_available = (
            self._settings.llm.has_gemini
            and self._gemini_breaker.is_available
        )
        health["gemini"] = gemini_health

        openai_health = self._openai_breaker.get_health()
        openai_health.is_primary = False
        openai_health.is_available = (
            self._settings.llm.has_openai
            and self._openai_breaker.is_available
        )
        health["openai"] = openai_health

        return health

    @property
    def current_provider(self) -> str | None:
        """Get the name of the most recently used provider."""
        return self._current_provider

    def reset_circuit_breakers(self) -> None:
        """Reset all circuit breakers to closed state (fresh instances)."""
        self._gemini_breaker = ProviderCircuitBreaker("gemini")
        self._openai_breaker = ProviderCircuitBreaker("openai")
        logger.info("All circuit breakers reset")
|
|