# DungeonMaster-AI — src/agents/llm_provider.py
"""
DungeonMaster AI - LLM Provider Chain
Manages LLM providers with automatic fallback from Gemini to OpenAI.
Includes circuit breaker pattern for reliability.
"""
from __future__ import annotations
import asyncio
import logging
import time
from datetime import datetime, timedelta
from enum import Enum
from llama_index.core.llms import ChatMessage, LLM
from llama_index.llms.gemini import Gemini
from llama_index.llms.openai import OpenAI
from src.config.settings import AppSettings, get_settings
from .exceptions import (
LLMAllProvidersFailedError,
LLMAuthenticationError,
LLMCircuitBreakerOpenError,
LLMQuotaExhaustedError,
LLMRateLimitError,
LLMTimeoutError,
)
from .models import LLMProviderHealth, LLMResponse
logger = logging.getLogger(__name__)
class CircuitState(str, Enum):
    """States of a provider circuit breaker."""

    # Normal operation: requests flow through to the provider.
    CLOSED = "closed"
    # Provider is considered down: requests are rejected.
    OPEN = "open"
    # Probing recovery: a single test request is allowed through.
    HALF_OPEN = "half_open"
class ProviderCircuitBreaker:
    """
    Circuit breaker for an individual LLM provider.

    Prevents cascading failures by temporarily blocking requests
    to a failing provider. After ``failure_threshold`` consecutive
    failures the circuit OPENs; after ``reset_timeout`` seconds it
    moves to HALF_OPEN and permits a single test request, which either
    CLOSEs the circuit (success) or re-OPENs it (failure).
    """

    def __init__(
        self,
        provider_name: str,
        failure_threshold: int = 3,
        reset_timeout: float = 60.0,
    ) -> None:
        """
        Args:
            provider_name: Identifier used in logs and health reports.
            failure_threshold: Consecutive failures before the circuit opens.
            reset_timeout: Seconds to wait before probing for recovery.
        """
        self.provider_name = provider_name
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: datetime | None = None
        self._last_success_time: datetime | None = None
        # Most recent failure message, surfaced via get_health().
        self._last_error: str | None = None

    @property
    def state(self) -> CircuitState:
        """Get current circuit state, checking for timeout transitions."""
        # Lazily transition OPEN -> HALF_OPEN once the reset timeout elapses;
        # the transition happens on read rather than via a background timer.
        if self._state == CircuitState.OPEN:
            if self._should_attempt_reset():
                self._state = CircuitState.HALF_OPEN
                logger.info(
                    f"Circuit breaker for {self.provider_name} "
                    "transitioning to HALF_OPEN"
                )
        return self._state

    @property
    def is_available(self) -> bool:
        """Check if provider is available for requests."""
        return self.state != CircuitState.OPEN

    def _should_attempt_reset(self) -> bool:
        """Check if enough time has passed to attempt reset."""
        if self._last_failure_time is None:
            return True
        elapsed = (datetime.now() - self._last_failure_time).total_seconds()
        return elapsed >= self.reset_timeout

    def record_success(self) -> None:
        """Record a successful request, resetting failure tracking."""
        self._failure_count = 0
        self._last_error = None
        self._last_success_time = datetime.now()
        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            logger.info(
                f"Circuit breaker for {self.provider_name} "
                "CLOSED after successful test"
            )

    def record_failure(self, error: Exception) -> None:
        """Record a failed request, opening the circuit if warranted.

        Args:
            error: The exception that caused the failure; its message is
                retained for health reporting.
        """
        self._failure_count += 1
        self._last_failure_time = datetime.now()
        self._last_error = str(error)
        logger.warning(
            f"Provider {self.provider_name} failure "
            f"({self._failure_count}/{self.failure_threshold}): {error}"
        )
        if self._state == CircuitState.HALF_OPEN:
            # Test request failed, back to OPEN
            self._state = CircuitState.OPEN
            logger.warning(
                f"Circuit breaker for {self.provider_name} "
                "OPEN after failed test"
            )
        elif self._failure_count >= self.failure_threshold:
            self._state = CircuitState.OPEN
            logger.warning(
                f"Circuit breaker for {self.provider_name} OPENED "
                f"after {self._failure_count} failures"
            )

    def get_health(self) -> LLMProviderHealth:
        """Get health status for this provider."""
        return LLMProviderHealth(
            provider_name=self.provider_name,
            is_available=self.is_available,
            is_primary=False,  # Set by parent
            consecutive_failures=self._failure_count,
            last_success=self._last_success_time,
            last_error=self._last_error,
            circuit_open=self.state == CircuitState.OPEN,
        )
class LLMFallbackChain:
    """
    Manages LLM providers with automatic fallback.

    Order: Gemini (primary) -> OpenAI (fallback) -> Error

    Features:
    - Circuit breaker per provider (3 failures, 60s reset)
    - Configurable timeouts (Gemini: 30s, OpenAI: 45s)
    - Automatic fallback on rate limits, timeouts, errors
    - Provider health tracking
    """

    # Default request timeout per provider, in seconds. __init__ copies
    # this dict per instance so per-instance overrides never mutate
    # shared class state.
    PROVIDER_TIMEOUTS: dict[str, float] = {
        "gemini": 30.0,
        "openai": 45.0,
    }

    def __init__(
        self,
        settings: AppSettings | None = None,
        gemini_timeout: float | None = None,
        openai_timeout: float | None = None,
    ) -> None:
        """
        Initialize the LLM fallback chain.

        Args:
            settings: Application settings. Defaults to get_settings().
            gemini_timeout: Override Gemini timeout.
            openai_timeout: Override OpenAI timeout.
        """
        self._settings = settings or get_settings()
        # Shadow the class-level defaults with an instance copy so that
        # timeout overrides apply to this chain only (previously the
        # class dict was mutated, leaking overrides across instances).
        self.PROVIDER_TIMEOUTS = dict(type(self).PROVIDER_TIMEOUTS)
        # `is not None` (not truthiness) so an explicit 0.0 is honored.
        if gemini_timeout is not None:
            self.PROVIDER_TIMEOUTS["gemini"] = gemini_timeout
        if openai_timeout is not None:
            self.PROVIDER_TIMEOUTS["openai"] = openai_timeout
        # Provider clients are created lazily on first use.
        self._gemini: Gemini | None = None
        self._openai: OpenAI | None = None
        # One circuit breaker per provider.
        self._gemini_breaker = ProviderCircuitBreaker("gemini")
        self._openai_breaker = ProviderCircuitBreaker("openai")
        # Track which provider is currently active.
        self._current_provider: str | None = None
        logger.debug("LLMFallbackChain initialized")

    def _get_gemini(self) -> Gemini | None:
        """Get or create the Gemini LLM instance (lazy, cached).

        Returns:
            The Gemini client, or None if the API key is missing or
            initialization fails.
        """
        if self._gemini is not None:
            return self._gemini
        if not self._settings.llm.has_gemini:
            logger.warning("Gemini API key not configured")
            return None
        try:
            self._gemini = Gemini(
                api_key=self._settings.llm.gemini_api_key,
                model=self._settings.llm.gemini_model,
                temperature=self._settings.llm.temperature,
                max_tokens=self._settings.llm.max_tokens,
            )
            logger.info(f"Gemini LLM initialized: {self._settings.llm.gemini_model}")
            return self._gemini
        except Exception as e:
            logger.error(f"Failed to initialize Gemini: {e}")
            return None

    def _get_openai(self) -> OpenAI | None:
        """Get or create the OpenAI LLM instance (lazy, cached).

        Returns:
            The OpenAI client, or None if the API key is missing or
            initialization fails.
        """
        if self._openai is not None:
            return self._openai
        if not self._settings.llm.has_openai:
            logger.warning("OpenAI API key not configured")
            return None
        try:
            self._openai = OpenAI(
                api_key=self._settings.llm.openai_api_key,
                model=self._settings.llm.openai_model,
                temperature=self._settings.llm.temperature,
                max_tokens=self._settings.llm.max_tokens,
            )
            logger.info(f"OpenAI LLM initialized: {self._settings.llm.openai_model}")
            return self._openai
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI: {e}")
            return None

    def get_primary_llm(self) -> LLM | None:
        """
        Get the primary LLM for use with LlamaIndex agents.

        Returns:
            Gemini if configured and its circuit is not open, otherwise
            OpenAI under the same conditions, otherwise None.
        """
        gemini = self._get_gemini()
        if gemini and self._gemini_breaker.is_available:
            return gemini
        openai = self._get_openai()
        if openai and self._openai_breaker.is_available:
            return openai
        return None

    def get_fallback_llm(self) -> LLM | None:
        """
        Get the fallback LLM for use when primary fails.

        Returns:
            OpenAI if Gemini is currently acting as primary, otherwise
            None (when Gemini is down, OpenAI is already the primary).
        """
        if self._get_gemini() and self._gemini_breaker.is_available:
            openai = self._get_openai()
            if openai and self._openai_breaker.is_available:
                return openai
        return None

    async def generate(
        self,
        messages: list[ChatMessage],
        timeout: float | None = None,
    ) -> LLMResponse:
        """
        Generate a response using the LLM chain with fallback.

        Args:
            messages: Chat messages to send.
            timeout: Optional timeout override (seconds) applied to
                whichever provider handles the request.

        Returns:
            LLMResponse with generated text and metadata; `from_fallback`
            is set when OpenAI served the request, and `latency_ms`
            covers the whole chain including any failed primary attempt.

        Raises:
            LLMAllProvidersFailedError: If all providers fail; carries a
                per-provider error message map.
        """
        errors: dict[str, str] = {}
        start_time = time.time()

        # Try Gemini first.
        gemini = self._get_gemini()
        if gemini and self._gemini_breaker.is_available:
            try:
                response = await self._call_provider(
                    provider_name="gemini",
                    llm=gemini,
                    messages=messages,
                    timeout=(
                        timeout
                        if timeout is not None
                        else self.PROVIDER_TIMEOUTS["gemini"]
                    ),
                    breaker=self._gemini_breaker,
                )
                response.latency_ms = (time.time() - start_time) * 1000
                return response
            except Exception as e:
                errors["gemini"] = str(e)
                logger.warning(f"Gemini failed, trying fallback: {e}")
        elif not self._gemini_breaker.is_available:
            errors["gemini"] = "Circuit breaker open"
        else:
            # Provider object could not be created (no key / init failure).
            errors["gemini"] = "Not configured"

        # Try OpenAI as fallback.
        openai = self._get_openai()
        if openai and self._openai_breaker.is_available:
            try:
                response = await self._call_provider(
                    provider_name="openai",
                    llm=openai,
                    messages=messages,
                    timeout=(
                        timeout
                        if timeout is not None
                        else self.PROVIDER_TIMEOUTS["openai"]
                    ),
                    breaker=self._openai_breaker,
                )
                response.from_fallback = True
                response.latency_ms = (time.time() - start_time) * 1000
                return response
            except Exception as e:
                errors["openai"] = str(e)
                logger.error(f"OpenAI fallback also failed: {e}")
        elif not self._openai_breaker.is_available:
            errors["openai"] = "Circuit breaker open"
        else:
            errors["openai"] = "Not configured"

        # All providers failed.
        raise LLMAllProvidersFailedError(errors)

    async def _call_provider(
        self,
        provider_name: str,
        llm: LLM,
        messages: list[ChatMessage],
        timeout: float,
        breaker: ProviderCircuitBreaker,
    ) -> LLMResponse:
        """
        Call a specific LLM provider with timeout and circuit breaker.

        Args:
            provider_name: Name of the provider.
            llm: LLM instance to use.
            messages: Messages to send.
            timeout: Request timeout in seconds.
            breaker: Circuit breaker for this provider.

        Returns:
            LLMResponse with generated text.

        Raises:
            LLMCircuitBreakerOpenError: If the breaker is open.
            LLMTimeoutError: If the call exceeds `timeout`.
            LLMRateLimitError / LLMAuthenticationError /
            LLMQuotaExhaustedError: Categorized from the provider error
                message (heuristic substring matching).
        """
        if not breaker.is_available:
            raise LLMCircuitBreakerOpenError(provider_name)
        try:
            response = await asyncio.wait_for(
                llm.achat(messages),
                timeout=timeout,
            )
            text = response.message.content if response.message else ""
            breaker.record_success()
            self._current_provider = provider_name
            return LLMResponse(
                text=str(text),
                provider_used=provider_name,
                model_used=getattr(llm, "model", "unknown"),
            )
        except asyncio.TimeoutError as e:
            breaker.record_failure(e)
            raise LLMTimeoutError(provider_name, timeout) from e
        except Exception as e:
            # Categorize the error by message content; checks are ordered
            # so "rate limit exceeded" maps to rate-limit, not quota.
            error_str = str(e).lower()
            if "rate" in error_str and "limit" in error_str:
                breaker.record_failure(e)
                raise LLMRateLimitError(provider_name) from e
            if "auth" in error_str or "api key" in error_str or "401" in error_str:
                breaker.record_failure(e)
                raise LLMAuthenticationError(provider_name) from e
            if "quota" in error_str or "exceeded" in error_str:
                breaker.record_failure(e)
                raise LLMQuotaExhaustedError(provider_name) from e
            # Generic error: record and re-raise unchanged.
            breaker.record_failure(e)
            raise

    def get_health(self) -> dict[str, LLMProviderHealth]:
        """
        Get health status of all providers.

        Returns:
            Dictionary mapping provider name to health status, with
            availability combining key configuration and breaker state.
        """
        health = {}
        gemini_health = self._gemini_breaker.get_health()
        gemini_health.is_primary = True
        gemini_health.is_available = (
            self._settings.llm.has_gemini
            and self._gemini_breaker.is_available
        )
        health["gemini"] = gemini_health
        openai_health = self._openai_breaker.get_health()
        openai_health.is_primary = False
        openai_health.is_available = (
            self._settings.llm.has_openai
            and self._openai_breaker.is_available
        )
        health["openai"] = openai_health
        return health

    @property
    def current_provider(self) -> str | None:
        """Get the name of the most recently used provider."""
        return self._current_provider

    def reset_circuit_breakers(self) -> None:
        """Reset all circuit breakers to closed state."""
        # Fresh breakers discard all failure history and timers.
        self._gemini_breaker = ProviderCircuitBreaker("gemini")
        self._openai_breaker = ProviderCircuitBreaker("openai")
        logger.info("All circuit breakers reset")