import logging
import re
import time
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union

from utils.config import config
| |
|
# Module-scoped logger; named after this module so its level and handlers can
# be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
| |
|
class LLMProvider(ABC):
    """Abstract base class for all LLM providers.

    Subclasses implement ``generate``, ``stream_generate`` and
    ``validate_model``. The base class supplies shared helpers:

    * ``_retry_with_backoff``: retries a callable with exponential backoff,
      guarded by a simple circuit breaker.
    * ``_classify_error`` / ``_is_recoverable_error``: heuristic
      classification of provider errors based on their message text.

    The circuit-breaker and backoff tunables are class attributes so a
    subclass (or a single instance) can override them without touching the
    retry logic.
    """

    # The circuit breaker opens once ``failure_count`` exceeds this value...
    CIRCUIT_BREAKER_THRESHOLD = 5
    # ...and stays open for this many seconds after the most recent failure.
    CIRCUIT_BREAKER_COOLDOWN = 60.0
    # Delay before retrying after attempt N (0-based) is
    # BACKOFF_BASE_SECONDS * 2**N, capped at BACKOFF_MAX_SECONDS.
    BACKOFF_BASE_SECONDS = 1.0
    BACKOFF_MAX_SECONDS = 10.0

    # Ordered (category, lowercase substrings, status-code pattern) rules used
    # by _classify_error; the first matching rule wins. Status codes only match
    # when they are not part of a longer run of digits, so text such as
    # "1500 tokens" or "id 84013" is not mistaken for an HTTP 500 / 401.
    _ERROR_RULES = (
        ('network', ('connection', 'timeout', 'resolve', 'unreachable'), None),
        ('authentication', ('auth', 'unauthorized', 'invalid token'),
         re.compile(r'(?<!\d)(?:401|403)(?!\d)')),
        ('rate_limit', ('rate limit', 'too many requests', 'quota exceeded'),
         re.compile(r'(?<!\d)429(?!\d)')),
        ('server', ('server error',),
         re.compile(r'(?<!\d)(?:500|502|503)(?!\d)')),
    )

    # Error categories for which a retry may plausibly succeed.
    _RECOVERABLE_ERROR_TYPES = ('network', 'rate_limit', 'server')

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        self.model_name = model_name
        # NOTE(review): stored but not used in this class; presumably applied
        # to requests by subclasses — confirm.
        self.timeout = timeout
        # Number of attempts _retry_with_backoff makes (values < 1 act as 1).
        self.max_retries = max_retries
        # NOTE(review): not read or written elsewhere in this class;
        # presumably maintained by subclasses or callers.
        self.is_available = True
        # Failed attempts recorded by _retry_with_backoff since the last
        # success (accumulates across calls); reset to 0 on success.
        self.failure_count = 0
        # time.time() of the most recent failed attempt, or None after success.
        self.last_failure_time = None

    @abstractmethod
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""

    @abstractmethod
    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support."""

    @abstractmethod
    def validate_model(self) -> bool:
        """Validate if the model is available."""

    def _circuit_is_open(self) -> bool:
        """Return True while the circuit breaker should reject new attempts.

        The breaker is open when more than ``CIRCUIT_BREAKER_THRESHOLD``
        failures have been recorded and the latest one happened less than
        ``CIRCUIT_BREAKER_COOLDOWN`` seconds ago. Once the cooldown elapses a
        single trial call is allowed through; its outcome either resets the
        counters (success) or re-opens the breaker (failure).
        """
        if self.failure_count <= self.CIRCUIT_BREAKER_THRESHOLD or self.last_failure_time is None:
            return False
        return time.time() - self.last_failure_time < self.CIRCUIT_BREAKER_COOLDOWN

    def _retry_with_backoff(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` with retries, backoff and a circuit breaker.

        Args:
            func: The callable to invoke.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns on the first successful attempt. A
            success resets ``failure_count`` and ``last_failure_time``.

        Raises:
            Exception: "Circuit breaker tripped - too many recent failures"
                as soon as the breaker is open. It is raised without calling
                ``func``, sleeping, or counting it as a failure, and is chained
                to the last real error, if any.
            Exception: Whatever ``func`` raised on the final attempt, once
                every attempt has failed.
        """
        # Always make at least one attempt; a non-positive max_retries would
        # otherwise skip the loop and end in `raise None` (TypeError).
        attempts = max(1, self.max_retries)
        last_exception = None

        for attempt in range(attempts):
            # Check the breaker outside the try so tripping it is neither
            # counted as a failure nor retried with backoff.
            if self._circuit_is_open():
                raise Exception("Circuit breaker tripped - too many recent failures") from last_exception

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                self.failure_count += 1
                self.last_failure_time = time.time()

                if attempt == attempts - 1:
                    logger.error("All %d attempts failed. Last error: %s", attempts, e)
                    raise

                sleep_time = min(self.BACKOFF_BASE_SECONDS * (2 ** attempt), self.BACKOFF_MAX_SECONDS)
                logger.warning("Attempt %d failed: %s. Retrying in %ss...", attempt + 1, e, sleep_time)
                time.sleep(sleep_time)
            else:
                self.failure_count = 0
                self.last_failure_time = None
                return result

    def _classify_error(self, error: Exception) -> str:
        """Classify an error by keyword-matching its message.

        Args:
            error: The exception to inspect; only ``str(error)`` is examined,
                case-insensitively.

        Returns:
            One of ``'network'``, ``'authentication'``, ``'rate_limit'``,
            ``'server'`` or ``'other'``. Rules in ``_ERROR_RULES`` are tried
            in that order and the first match wins, so a message mentioning
            both a connection problem and a 503 is classified as ``'network'``.
        """
        error_str = str(error).lower()
        for category, phrases, code_pattern in self._ERROR_RULES:
            if any(phrase in error_str for phrase in phrases):
                return category
            if code_pattern is not None and code_pattern.search(error_str):
                return category
        return 'other'

    def _is_recoverable_error(self, error: Exception) -> bool:
        """Return True if ``error`` is classified as network, rate-limit or server."""
        return self._classify_error(error) in self._RECOVERABLE_ERROR_TYPES
| |
|