Spaces:
Running
Running
| """ | |
| LLM client orchestrator for multi-provider inference. | |
| This module provides a unified, scalable interface for interacting with | |
| multiple Large Language Model providers including: | |
| - Google Gemini | |
| - Anthropic Claude | |
| - OpenAI (GPT series) | |
| - DeepSeek | |
| - Qwen (Alibaba) | |
| - Kimi (Moonshot AI) | |
| - OpenRouter (aggregator) | |
| - Custom HTTP endpoints via generic client | |
| Performance Characteristics | |
| --------------------------- | |
| - Request latency: Provider-dependent (typically 200ms-5s for inference) | |
| - Retry overhead: O(log n) for exponential backoff with max_attempts=3-5 | |
| - Token counting: O(L) where L = character length via provider-specific tokenizer | |
| - Memory: O(1) per request + O(P) for provider configs, P = provider count | |
| Thread Safety | |
| ------------- | |
| - All client methods are async and reentrant | |
| - Rate limiters, circuit breakers, and concurrency semaphores use asyncio.Lock | |
| - No shared mutable state beyond configured provider instances | |
| - Safe for concurrent use in FastAPI/Starlette applications | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import time | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from enum import Enum, auto | |
| from typing import Any, AsyncGenerator, Callable, Dict, List, Optional | |
| import httpx | |
| from pydantic import field_validator | |
| from nlproxy.utils.constants import PROVIDER_PRICING | |
| try: | |
| from google import genai | |
| from google.genai import types | |
| _GEMINI_AVAILABLE = True | |
| except ImportError: | |
| _GEMINI_AVAILABLE = False | |
| genai = None | |
| types = None | |
| try: | |
| import tiktoken | |
| _TIKTOKEN_AVAILABLE = True | |
| except ImportError: | |
| _TIKTOKEN_AVAILABLE = False | |
| tiktoken = None # type: ignore | |
| # Configure logger with project-standard format | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================= | |
| # CONFIGURATION & PRICING | |
| # ============================================================================= | |
@dataclass
class ProviderPricing:
    """
    Pricing configuration for an LLM provider.

    Declared as a dataclass so instances can be built from values, e.g.
    ``ProviderPricing(input_price=0.5, output_price=1.5)``.

    Attributes
    ----------
    input_price : float
        Cost per 1000 input tokens in USD.
    output_price : float
        Cost per 1000 output tokens in USD.
    currency : str
        Currency code (default: "USD").
    """

    input_price: float
    output_price: float
    currency: str = "USD"
| # Provider pricing is imported from shared constants; env vars may override values at import time. | |
@dataclass
class RetryConfig:
    """
    Configuration for retry behavior with exponential backoff.

    Declared as a dataclass so individual settings can be overridden per
    instance, e.g. ``RetryConfig(max_attempts=5)``.

    Attributes
    ----------
    max_attempts : int
        Maximum number of attempts (default: 3).
    base_delay : float
        Initial delay in seconds before the first retry (default: 1.0).
    max_delay : float
        Maximum delay cap in seconds (default: 30.0).
    exponential_base : float
        Base for the exponential backoff calculation (default: 2.0).
    jitter : bool
        Whether to add random jitter to delays (default: True).
    retryable_exceptions : tuple
        Exception types that trigger a retry.
    """

    max_attempts: int = 3
    base_delay: float = 1.0
    max_delay: float = 30.0
    exponential_base: float = 2.0
    jitter: bool = True
    # A tuple is immutable, so it is safe as a shared dataclass default.
    # NOTE(review): httpx.HTTPStatusError also covers 4xx responses, so client
    # errors would be retried as well -- confirm this is intended.
    retryable_exceptions: tuple = (
        httpx.TimeoutException,
        httpx.NetworkError,
        httpx.HTTPStatusError,
        ConnectionError,
        asyncio.TimeoutError,
    )
@dataclass
class TimeoutConfig:
    """
    Timeout configuration for LLM requests.

    Declared as a dataclass so individual timeouts can be overridden per
    instance, e.g. ``TimeoutConfig(read=120.0)``.

    Attributes
    ----------
    connect : float
        Connection timeout in seconds (default: 10.0).
    read : float
        Read/response timeout in seconds (default: 60.0).
    write : float
        Write/request timeout in seconds (default: 10.0).
    pool : float
        Connection pool timeout in seconds (default: 10.0).
    """

    connect: float = 10.0
    read: float = 60.0
    write: float = 10.0
    pool: float = 10.0
@dataclass
class RateLimitConfig:
    """
    Rate limiting configuration using a token bucket algorithm.

    Declared as a dataclass so limits can be set per instance, e.g.
    ``RateLimitConfig(requests_per_minute=60)``.

    Attributes
    ----------
    requests_per_minute : Optional[int]
        Maximum requests allowed per minute. None = unlimited.
    tokens_per_request : int
        Estimated token cost per request for rate calculation.
    bucket_capacity : int
        Maximum tokens in the bucket (burst capacity).
    refill_rate : float
        Tokens added per second (sustained rate).
    """

    requests_per_minute: Optional[int] = None
    tokens_per_request: int = 1000
    bucket_capacity: int = 10000
    refill_rate: float = 100.0
| # ============================================================================= | |
| # DATA MODELS | |
| # ============================================================================= | |
class LLMProvider(str, Enum):
    """
    Identifiers of the LLM back-ends this module can talk to.

    Each member's value is the lowercase provider name used in
    configuration. Members are also ``str`` instances, so they compare
    equal to their plain-string value.
    """

    GEMINI = "gemini"
    CLAUDE = "claude"
    OPENAI = "openai"
    DEEPSEEK = "deepseek"
    QWEN = "qwen"
    KIMI = "kimi"
    OPENROUTER = "openrouter"
    CUSTOM = "custom"
class RequestStatus(str, Enum):
    """Stages and terminal outcomes a request can be in during its lifecycle."""

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    RETRYING = "retrying"
    TIMEOUT = "timeout"
    RATE_LIMITED = "rate_limited"
    CIRCUIT_OPEN = "circuit_open"
@dataclass
class LLMRequest:
    """
    Unified request model for all LLM providers.

    Declared as a dataclass; the field validators below run from
    ``__post_init__`` so an invalid request fails at construction time.

    Attributes
    ----------
    prompt : str
        Input text to send to the LLM (required). Stored stripped of
        leading/trailing whitespace.
    provider : LLMProvider
        Target provider for inference.
    model : str
        Model identifier (e.g., "gpt-4", "claude-3-opus").
    max_tokens : int
        Maximum tokens to generate (default: 512).
    temperature : float
        Sampling temperature in [0.0, 2.0] (default: 0.7).
    top_p : float
        Nucleus sampling threshold in [0.0, 1.0] (default: 0.95).
    top_k : int
        Top-k sampling parameter (default: 40).
    stop_sequences : Optional[List[str]]
        Sequences that trigger generation stop.
    metadata : Optional[Dict[str, Any]]
        Additional metadata for logging/tracing.

    Raises
    ------
    ValueError
        If the prompt is empty/blank, or temperature/top_p are out of range.
    """

    prompt: str
    provider: LLMProvider
    model: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    stop_sequences: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        """Run every field validator and store the normalized values."""
        self.prompt = self.prompt_must_not_be_empty(self.prompt)
        self.temperature = self.temperature_in_range(self.temperature)
        self.top_p = self.top_p_in_range(self.top_p)

    @classmethod
    def prompt_must_not_be_empty(cls, v: str) -> str:
        """Return the stripped prompt; reject empty or whitespace-only text."""
        if not v or not v.strip():
            raise ValueError("prompt must not be empty")
        return v.strip()

    @classmethod
    def temperature_in_range(cls, v: float) -> float:
        """Return ``v`` unchanged if it lies in [0.0, 2.0], else raise."""
        if not 0.0 <= v <= 2.0:
            raise ValueError("temperature must be in [0.0, 2.0]")
        return v

    @classmethod
    def top_p_in_range(cls, v: float) -> float:
        """Return ``v`` unchanged if it lies in [0.0, 1.0], else raise."""
        if not 0.0 <= v <= 1.0:
            raise ValueError("top_p must be in [0.0, 1.0]")
        return v
@dataclass
class LLMResponse:
    """
    Unified response model from all LLM providers.

    Declared as a dataclass so the ``field(default_factory=...)`` defaults
    below are honored: each instance gets its own ``metadata`` dict and its
    own ``timestamp``. Instances are mutable.

    Attributes
    ----------
    text : str
        Generated text output.
    provider : LLMProvider
        Provider that generated the response.
    model : str
        Model that generated the response.
    input_tokens : int
        Number of tokens in the input prompt.
    output_tokens : int
        Number of tokens in the generated output.
    latency_ms : float
        End-to-end latency in milliseconds.
    cost_usd : float
        Estimated cost in USD based on provider pricing.
    metadata : Dict[str, Any]
        Provider-specific metadata (finish_reason, logprobs, etc.).
    request_id : str
        Unique identifier for tracing/logging.
    timestamp : datetime
        Response generation timestamp (naive UTC, from ``datetime.utcnow``).
    """

    text: str
    provider: LLMProvider
    model: str
    input_tokens: int
    output_tokens: int
    latency_ms: float
    cost_usd: float
    metadata: Dict[str, Any] = field(default_factory=dict)
    request_id: str = ""
    # NOTE(review): utcnow() yields a naive datetime and is deprecated since
    # Python 3.12; kept so existing comparisons against naive values still work.
    timestamp: datetime = field(default_factory=datetime.utcnow)
@dataclass
class LLMError:
    """
    Structured error information for failed requests.

    Declared as a dataclass so the ``field(default_factory=...)`` timestamp
    default is evaluated per instance.

    Attributes
    ----------
    message : str
        Human-readable error description.
    error_type : str
        Categorized error type (timeout, rate_limit, auth, etc.).
    provider : LLMProvider
        Provider where the error occurred.
    model : Optional[str]
        Model involved (if applicable).
    retryable : bool
        Whether the error is transient and retryable.
    details : Optional[Dict[str, Any]]
        Additional error context (status code, response body, etc.).
    timestamp : datetime
        Error occurrence timestamp (naive UTC, from ``datetime.utcnow``).
    """

    message: str
    error_type: str
    provider: LLMProvider
    model: Optional[str] = None
    retryable: bool = False
    details: Optional[Dict[str, Any]] = None
    # NOTE(review): utcnow() yields a naive datetime and is deprecated since
    # Python 3.12; kept to avoid changing the timestamp type callers see.
    timestamp: datetime = field(default_factory=datetime.utcnow)
| # ============================================================================= | |
| # CIRCUIT BREAKER (Improved: distinguishes retryable vs non-retryable) | |
| # ============================================================================= | |
class CircuitState(Enum):
    """The three states a circuit breaker can be in; values are auto-numbered."""

    CLOSED = auto()
    OPEN = auto()
    HALF_OPEN = auto()
class CircuitBreaker:
    """
    Circuit breaker for fault tolerance in LLM provider calls.

    Distinguishes between retryable and non-retryable errors: only
    non-retryable failures increment the failure count that can trip the
    circuit.

    State Transitions:
    ------------------
    CLOSED → OPEN: When non-retryable failures ≥ threshold within window
    OPEN → HALF_OPEN: After recovery_timeout expires
    HALF_OPEN → CLOSED: When success_count ≥ success_threshold
    HALF_OPEN → OPEN: On any non-retryable failure during testing

    Parameters
    ----------
    failure_threshold : int
        Non-retryable failures needed to open the circuit (default: 5).
    recovery_timeout : float
        Seconds the circuit stays OPEN before probing (default: 30.0).
    success_threshold : int
        Successes needed in HALF_OPEN to close the circuit (default: 3).
    window_seconds : float
        If the previous failure is older than this, the failure count is
        reset before the new failure is counted (default: 60.0).
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        success_threshold: int = 3,
        window_seconds: float = 60.0,
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.success_threshold = success_threshold
        self.window_seconds = window_seconds
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._success_count = 0
        self._last_failure_time: Optional[float] = None
        self._lock = asyncio.Lock()

    def _recovery_elapsed(self) -> bool:
        """Return True once recovery_timeout has passed since the last failure."""
        if self._last_failure_time is None:
            return False
        return time.time() - self._last_failure_time >= self.recovery_timeout

    def _sync_state(self) -> CircuitState:
        """
        Persist the time-based OPEN → HALF_OPEN transition and return the state.

        Must be called with ``self._lock`` held. Without persisting the
        transition, ``record_success`` / ``record_failure`` would keep seeing
        OPEN and the circuit could never close again.
        """
        if self._state == CircuitState.OPEN and self._recovery_elapsed():
            self._state = CircuitState.HALF_OPEN
            # Start a fresh probe round.
            self._success_count = 0
        return self._state

    @property
    def state(self) -> CircuitState:
        """Current circuit state; reports HALF_OPEN once recovery_timeout expired.

        Read-only view: it does not mutate the breaker, so it is safe to call
        without holding the lock (e.g. from ``get_stats``).
        """
        if self._state == CircuitState.OPEN and self._recovery_elapsed():
            return CircuitState.HALF_OPEN
        return self._state

    async def can_execute(self) -> bool:
        """Check if a request can proceed through the circuit.

        Returns True in CLOSED and HALF_OPEN, False while OPEN.
        """
        async with self._lock:
            return self._sync_state() != CircuitState.OPEN

    async def record_success(self) -> None:
        """Record a successful request.

        In HALF_OPEN this counts toward closing the circuit; in CLOSED it
        clears the accumulated failure count.
        """
        async with self._lock:
            current = self._sync_state()
            if current == CircuitState.HALF_OPEN:
                self._success_count += 1
                if self._success_count >= self.success_threshold:
                    self._reset()
                    logger.info("Circuit breaker CLOSED after successful recovery")
            elif current == CircuitState.CLOSED:
                self._failure_count = 0

    async def record_failure(self, retryable: bool = False) -> None:
        """
        Record a failed request.

        Only non-retryable errors increment the failure count that can trip
        the circuit breaker. Retryable errors are logged and leave the
        circuit state untouched.

        Parameters
        ----------
        retryable : bool
            Whether the error is transient and retryable.
        """
        async with self._lock:
            now = time.time()
            # Resolve the effective state before _last_failure_time is updated.
            current = self._sync_state()
            # Forget failures once the previous one is older than the window.
            if (
                self._last_failure_time is not None
                and now - self._last_failure_time > self.window_seconds
            ):
                self._failure_count = 0
            if retryable:
                logger.debug(f"Retryable error recorded; circuit breaker state unchanged: {self._state.name}")
                return
            self._failure_count += 1
            self._last_failure_time = now
            if current == CircuitState.HALF_OPEN:
                self._state = CircuitState.OPEN
                logger.warning("Circuit breaker re-OPENED after non-retryable failure in HALF_OPEN")
            elif current == CircuitState.CLOSED:
                if self._failure_count >= self.failure_threshold:
                    self._state = CircuitState.OPEN
                    logger.warning(
                        f"Circuit breaker OPENED: {self._failure_count} non-retryable failures "
                        f"in {self.window_seconds}s window"
                    )

    def _reset(self) -> None:
        """Reset circuit breaker to initial closed state."""
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._success_count = 0
        self._last_failure_time = None

    def get_stats(self) -> Dict[str, Any]:
        """Return circuit breaker statistics for monitoring."""
        return {
            "state": self.state.name,
            "failure_count": self._failure_count,
            "success_count": self._success_count,
            "last_failure_time": self._last_failure_time,
            "failure_threshold": self.failure_threshold,
            "recovery_timeout": self.recovery_timeout,
        }
| # ============================================================================= | |
| # RATE LIMITER + CONCURRENCY SEMAPHORE | |
| # ============================================================================= | |
class TokenBucket:
    """
    Token bucket rate limiter for controlling request throughput.

    Allows burst traffic up to ``capacity`` while sustaining
    ``refill_rate`` tokens per second long-term.

    Parameters
    ----------
    capacity : float
        Maximum number of tokens the bucket can hold; also its initial fill.
    refill_rate : float
        Tokens added per second.
    """

    def __init__(self, capacity: float, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        self._tokens = capacity
        # Monotonic clock: elapsed time must never go negative when the
        # wall clock is adjusted.
        self._last_refill = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self, tokens: float = 1.0, timeout: Optional[float] = None) -> bool:
        """
        Attempt to acquire tokens from the bucket.

        Parameters
        ----------
        tokens : float
            Number of tokens to take.
        timeout : Optional[float]
            Maximum seconds to wait; None waits as long as needed.

        Returns
        -------
        bool
            True if the tokens were taken. False if the wait would exceed
            ``timeout``, or if the request can never be satisfied (more
            tokens than ``capacity``, or an empty bucket that never refills).
        """
        # A request larger than the bucket can never succeed; fail fast
        # instead of polling forever.
        if tokens > self.capacity:
            return False
        started = time.monotonic()
        while True:
            async with self._lock:
                now = time.monotonic()
                refilled = self._tokens + self.refill_rate * (now - self._last_refill)
                self._tokens = min(self.capacity, refilled)
                self._last_refill = now
                if self._tokens >= tokens:
                    self._tokens -= tokens
                    return True
                # Without a positive refill rate the deficit never shrinks
                # (and the division below would fail).
                if self.refill_rate <= 0:
                    return False
                wait_time = (tokens - self._tokens) / self.refill_rate
            if timeout is not None:
                waited = time.monotonic() - started
                if waited + wait_time > timeout:
                    return False
            # Sleep in short slices, outside the lock, so other waiters can
            # make progress and the bucket is re-evaluated regularly.
            await asyncio.sleep(min(wait_time, 0.1))

    def get_stats(self) -> Dict[str, float]:
        """Return current bucket statistics (token level as of the last refill)."""
        return {
            "available_tokens": self._tokens,
            "capacity": self.capacity,
            "refill_rate": self.refill_rate,
        }
| # ============================================================================= | |
| # PROVIDER-SPECIFIC TOKENIZERS | |
| # ============================================================================= | |
class TokenCounter(ABC):
    """
    Abstract base class for provider-specific token counting.

    Each provider supplies a subclass that counts tokens with its native
    tokenizer or an approximation. ``count_tokens`` is abstract, so the base
    class itself cannot be instantiated.
    """

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """Count tokens in ``text`` using a provider-specific method."""
        raise NotImplementedError
class OpenAITokenCounter(TokenCounter):
    """Counts tokens for OpenAI models with tiktoken, when it is installed."""

    def __init__(self, model: str):
        # Without tiktoken there is no encoder; count_tokens then estimates.
        self.encoding = None
        if not _TIKTOKEN_AVAILABLE:
            return
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # tiktoken has no mapping for this model name: use cl100k_base.
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Return the encoded length, or ``len(text) // 4`` without an encoder."""
        if not self.encoding:
            return len(text) // 4
        encoded = self.encoding.encode(text)
        return len(encoded)
class ClaudeTokenCounter(TokenCounter):
    """
    Token counter for Anthropic Claude models.

    Tries the ``anthropic`` SDK client's ``count_tokens`` first; falls back
    to tiktoken with the cl100k_base encoding (an approximation), and
    finally to ``len(text) // 4`` when neither is usable.
    """

    def __init__(self, model: str):
        self.model = model
        self._client = None
        self._has_native = False
        try:
            from anthropic import Anthropic

            # Placeholder key, as in the original; presumably the tokenizer
            # does not need a valid one -- confirm against the SDK version.
            self._client = Anthropic(api_key="dummy")
            self._has_native = True
        except Exception:
            # Covers a missing SDK (ImportError) and client construction
            # failures alike; the fallback encoding below is used instead.
            logger.debug("Anthropic tokenizer unavailable; using fallback token counting")
        # Always initialise the fallback so count_tokens can rely on the
        # attribute even when the native tokenizer fails at call time.
        self.encoding = (
            tiktoken.get_encoding("cl100k_base") if _TIKTOKEN_AVAILABLE else None
        )

    def count_tokens(self, text: str) -> int:
        """Return the token count for ``text`` using the best available method."""
        if self._has_native:
            try:
                return self._client.count_tokens(text)
            except Exception:
                # Best-effort: any native failure falls through to the
                # approximations below, but is no longer invisible.
                logger.debug("Native Claude token counting failed; using fallback", exc_info=True)
        if self.encoding:
            return len(self.encoding.encode(text))
        return len(text) // 4
class GeminiTokenCounter(TokenCounter):
    """
    Token counter for Google Gemini models.

    Tries ``genai.count_tokens`` from the imported ``google.genai`` module
    first; falls back to tiktoken with the cl100k_base encoding, and finally
    to ``len(text) // 4``.
    """

    def __init__(self, model: str):
        self.model = model
        self._has_native = _GEMINI_AVAILABLE
        # Resolve the fallback encoding once, like the sibling counters,
        # instead of looking it up on every call.
        self._encoding = (
            tiktoken.get_encoding("cl100k_base") if _TIKTOKEN_AVAILABLE else None
        )

    def count_tokens(self, text: str) -> int:
        """Return the token count for ``text`` using the best available method."""
        if self._has_native and genai:
            try:
                # NOTE(review): confirm the installed SDK exposes a
                # module-level count_tokens; if it does not, this raises on
                # every call and the fallback below is always used.
                return genai.count_tokens(text).total_tokens
            except Exception:
                # Best-effort: fall through to the approximations below.
                logger.debug("Native Gemini token counting failed; using fallback", exc_info=True)
        if self._encoding is not None:
            return len(self._encoding.encode(text))
        return len(text) // 4
class GenericTokenCounter(TokenCounter):
    """
    Catch-all counter for providers that have no dedicated tokenizer here.

    Approximates with tiktoken's cl100k_base encoding when tiktoken is
    installed, otherwise with ``len(text) // 4``.
    """

    def __init__(self, model: str):
        # ``model`` is accepted for a uniform constructor signature; it does
        # not influence which encoding is chosen.
        self.encoding = (
            tiktoken.get_encoding("cl100k_base") if _TIKTOKEN_AVAILABLE else None
        )

    def count_tokens(self, text: str) -> int:
        """Return the approximate number of tokens in ``text``."""
        if not self.encoding:
            return len(text) // 4
        return len(self.encoding.encode(text))
def get_token_counter(provider: LLMProvider, model: str) -> TokenCounter:
    """
    Build the token counter that matches ``provider``.

    Parameters
    ----------
    provider : LLMProvider
        Target provider.
    model : str
        Model identifier, forwarded to the counter's constructor.

    Returns
    -------
    TokenCounter
        An OpenAI, Claude or Gemini counter for those providers; a
        ``GenericTokenCounter`` for every other provider.
    """
    dedicated_counters = (
        (LLMProvider.OPENAI, OpenAITokenCounter),
        (LLMProvider.CLAUDE, ClaudeTokenCounter),
        (LLMProvider.GEMINI, GeminiTokenCounter),
    )
    # Compared with ``==`` in the same order as before, so plain-string
    # provider values keep matching their enum members.
    for known_provider, counter_cls in dedicated_counters:
        if provider == known_provider:
            return counter_cls(model)
    return GenericTokenCounter(model)
| # ============================================================================= | |
| # ABSTRACT PROVIDER CLIENT | |
| # ============================================================================= | |
| class BaseLLMClient(ABC): | |
| """ | |
| Abstract base class for LLM provider clients. | |
| Defines the interface that all provider implementations must follow. | |
| """ | |
| PROVIDER: LLMProvider | |
| DEFAULT_MODEL: str | |
def __init__(
    self,
    api_key: str,
    model: Optional[str] = None,
    retry_config: Optional[RetryConfig] = None,
    timeout_config: Optional[TimeoutConfig] = None,
    rate_limit_config: Optional[RateLimitConfig] = None,
    circuit_breaker: Optional[CircuitBreaker] = None,
    base_url: Optional[str] = None,
    max_concurrent_requests: int = 10,
):
    """
    Set up the shared machinery of a provider client.

    Parameters
    ----------
    api_key : str
        Credential stored on the client for the provider implementation.
    model : Optional[str]
        Model identifier stored on the client.
    retry_config : Optional[RetryConfig]
        Retry settings; a fresh ``RetryConfig()`` when omitted.
    timeout_config : Optional[TimeoutConfig]
        Timeout settings; a fresh ``TimeoutConfig()`` when omitted.
    rate_limit_config : Optional[RateLimitConfig]
        Enables the token-bucket limiter when ``requests_per_minute`` is set.
    circuit_breaker : Optional[CircuitBreaker]
        Breaker to use; a new ``CircuitBreaker()`` when omitted.
    base_url : Optional[str]
        Base URL stored on the client.
    max_concurrent_requests : int
        Size of the concurrency semaphore (default: 10).
    """
    self.api_key = api_key
    self.model = model
    # Build defaults per instance: default-argument objects are created once
    # at definition time and would be shared (and mutated) across clients.
    self.retry_config = retry_config if retry_config is not None else RetryConfig()
    self.timeout_config = timeout_config if timeout_config is not None else TimeoutConfig()
    self.rate_limit_config = rate_limit_config
    self.circuit_breaker = circuit_breaker or CircuitBreaker()
    self.base_url = base_url
    self.max_concurrent_requests = max_concurrent_requests
    # Rate limiter is only created when a per-minute limit is configured.
    self._rate_limiter: Optional[TokenBucket] = None
    if rate_limit_config and rate_limit_config.requests_per_minute:
        tokens_per_sec = (
            rate_limit_config.requests_per_minute
            * rate_limit_config.tokens_per_request
            / 60.0
        )
        self._rate_limiter = TokenBucket(
            capacity=rate_limit_config.bucket_capacity,
            refill_rate=tokens_per_sec,
        )
    # Semaphore for concurrency limiting per provider client.
    self._concurrency_semaphore = asyncio.Semaphore(max_concurrent_requests)
    # HTTP client for providers using a REST API; created lazily.
    self._http_client: Optional[httpx.AsyncClient] = None
    # Provider-specific token counter; set by _init_token_counter.
    self._token_counter: Optional[TokenCounter] = None
    # Metrics tracking.
    self._request_count = 0
    self._error_count = 0
    self._total_latency_ms = 0.0
| async def _generate_internal(self, request: LLMRequest) -> LLMResponse: | |
| """Internal generation logic specific to the provider.""" | |
| pass | |
| async def _generate_stream_internal( | |
| self, request: LLMRequest | |
| ) -> AsyncGenerator[str, None]: | |
| """Streaming generation logic specific to the provider.""" | |
| pass | |
def _init_token_counter(self, model: str) -> None:
    """Create and store the token counter for this client's provider and ``model``."""
    counter = get_token_counter(self.PROVIDER, model)
    self._token_counter = counter
| def _count_tokens(self, text: str) -> int: | |
| """Count tokens using provider-specific method.""" | |
| if self._token_counter: | |
| return self._token_counter.count_tokens(text) | |
| # Fallback | |
| if _TIKTOKEN_AVAILABLE: | |
| enc = tiktoken.get_encoding("cl100k_base") | |
| return len(enc.encode(text)) | |
| return len(text) // 4 | |
| def _ensure_http_client( | |
| self, | |
| base_url: str, | |
| headers: Dict[str, str], | |
| ) -> httpx.AsyncClient: | |
| """Return a single shared AsyncClient instance for this provider client.""" | |
| if self._http_client is None: | |
| self._http_client = httpx.AsyncClient( | |
| base_url=base_url, | |
| timeout=httpx.Timeout( | |
| connect=self.timeout_config.connect, | |
| read=self.timeout_config.read, | |
| write=self.timeout_config.write, | |
| pool=self.timeout_config.pool, | |
| ), | |
| headers=headers, | |
| ) | |
| return self._http_client | |
| def _get_pricing_key(self, model: str) -> str: | |
| """Get pricing key for cost calculation.""" | |
| pass | |
| async def _apply_rate_limit(self, tokens: float = 1.0) -> None: | |
| """Apply rate limiting if configured.""" | |
| if self._rate_limiter: | |
| acquired = await self._rate_limiter.acquire( | |
| tokens=tokens, timeout=30.0 | |
| ) | |
| if not acquired: | |
| raise LLMProviderError( | |
| message="Rate limit timeout: could not acquire tokens", | |
| error_type="rate_limit", | |
| provider=self.PROVIDER, | |
| retryable=True, | |
| ) | |
| async def _acquire_concurrency_slot(self) -> None: | |
| """Acquire a slot from the concurrency semaphore.""" | |
| await self._concurrency_semaphore.acquire() | |
| def _release_concurrency_slot(self) -> None: | |
| """Release a slot from the concurrency semaphore.""" | |
| self._concurrency_semaphore.release() | |
async def _with_retry(
    self,
    operation: Callable[[], Any],
    operation_name: str,
) -> Any:
    """
    Execute ``operation`` with retry logic and exponential backoff.

    Parameters
    ----------
    operation : Callable[[], Any]
        Zero-argument callable returning an awaitable; awaited once per attempt.
    operation_name : str
        Label used in log and error messages.

    Returns
    -------
    Any
        Whatever the first successful attempt returns.

    Raises
    ------
    LLMProviderError
        Re-raised unchanged when the operation raises one that is not listed
        in ``retryable_exceptions``; raised with
        ``error_type="max_retries_exceeded"`` when every attempt failed with
        a retryable exception; raised with the lowercased exception class
        name as ``error_type`` for any other exception. Wrapped errors are
        chained to their cause.
    """
    cfg = self.retry_config
    provider_name = self.PROVIDER.value
    last_error: Optional[Exception] = None
    for attempt in range(1, cfg.max_attempts + 1):
        try:
            logger.debug(
                f"[{provider_name}] {operation_name} attempt {attempt}/"
                f"{cfg.max_attempts}"
            )
            return await operation()
        except cfg.retryable_exceptions as e:
            last_error = e
            logger.warning(
                f"[{provider_name}] {operation_name} failed (attempt {attempt}): "
                f"{type(e).__name__}: {e}"
            )
            if attempt < cfg.max_attempts:
                # Exponential backoff capped at max_delay, plus a small jitter.
                delay = min(
                    cfg.base_delay * (cfg.exponential_base ** (attempt - 1)),
                    cfg.max_delay,
                )
                if cfg.jitter:
                    delay += random.uniform(0, 0.1 * cfg.base_delay)
                logger.info(f"[{provider_name}] Retrying in {delay:.2f}s...")
                await asyncio.sleep(delay)
            else:
                logger.error(
                    f"[{provider_name}] {operation_name} failed after "
                    f"{cfg.max_attempts} attempts"
                )
        except LLMProviderError:
            # Already structured by the provider: re-wrapping would discard
            # its error_type and retryable flag.
            raise
        except Exception as e:
            logger.error(
                f"[{provider_name}] {operation_name} failed with "
                f"non-retryable error: {type(e).__name__}: {e}"
            )
            raise LLMProviderError(
                message=str(e),
                error_type=type(e).__name__.lower(),
                provider=self.PROVIDER,
                retryable=False,
                details={"original_exception": str(e)},
            ) from e
    # Only reached after retryable failures (non-retryable ones raise above),
    # so the breaker is told the errors were transient.
    await self.circuit_breaker.record_failure(retryable=True)
    raise LLMProviderError(
        message=f"{operation_name} failed after {cfg.max_attempts} attempts",
        error_type="max_retries_exceeded",
        provider=self.PROVIDER,
        retryable=False,
        details={"last_error": str(last_error)},
    ) from last_error
async def generate(self, request: LLMRequest) -> LLMResponse:
    """
    Generate text from the LLM with full error handling and metrics.

    Steps: circuit-breaker check, concurrency slot, rate limit, provider
    call with retries, then token counts, latency and cost are written onto
    the response.

    Parameters
    ----------
    request : LLMRequest
        The generation request.

    Returns
    -------
    LLMResponse
        The provider response with ``request_id``, ``input_tokens``,
        ``output_tokens``, ``latency_ms`` and ``cost_usd`` filled in.

    Raises
    ------
    LLMProviderError
        On an open circuit, rate-limit timeout, exhausted retries or any
        provider failure. Unexpected exceptions are wrapped with
        ``error_type="unexpected"`` and chained to their cause.
    """
    request_id = f"{self.PROVIDER.value}-{int(time.time() * 1000)}-{random.randint(1000, 9999)}"
    start_time = time.time()
    logger.info(
        f"[{request_id}] Starting {self.PROVIDER.value} request: "
        f"model={request.model}, prompt_length={len(request.prompt)}"
    )
    try:
        # Check circuit breaker
        if not await self.circuit_breaker.can_execute():
            raise LLMProviderError(
                message="Circuit breaker is OPEN; provider unavailable",
                error_type="circuit_open",
                provider=self.PROVIDER,
                model=request.model,
                retryable=True,
            )
        # Acquire concurrency slot; released in the finally below.
        await self._acquire_concurrency_slot()
        try:
            # Apply rate limiting (prompt size in thousands of tokens).
            input_tokens = self._count_tokens(request.prompt)
            await self._apply_rate_limit(tokens=input_tokens / 1000)
            # Execute generation with retry
            response = await self._with_retry(
                lambda: self._generate_internal(request),
                operation_name=f"generate({request.model})"
            )
            # Record success
            await self.circuit_breaker.record_success()
            # Update metrics
            latency_ms = (time.time() - start_time) * 1000
            self._request_count += 1
            self._total_latency_ms += latency_ms
            # Add request metadata to response
            response.request_id = request_id
            response.input_tokens = input_tokens
            response.output_tokens = self._count_tokens(response.text)
            response.latency_ms = latency_ms
            # Calculate cost; prices are per 1000 tokens, hence the division.
            pricing_key = self._get_pricing_key(request.model)
            pricing = PROVIDER_PRICING.get(
                pricing_key,
                PROVIDER_PRICING.get("custom/*", {"input": 0.0, "output": 0.0})
            )
            response.cost_usd = (
                response.input_tokens * pricing.get("input", 0.0) +
                response.output_tokens * pricing.get("output", 0.0)
            ) / 1000
            logger.info(
                f"[{request_id}] {self.PROVIDER.value} success: "
                f"latency={latency_ms:.0f}ms, tokens_in={response.input_tokens}, "
                f"tokens_out={response.output_tokens}, cost=${response.cost_usd:.4f}"
            )
            return response
        finally:
            self._release_concurrency_slot()
    except LLMProviderError as e:
        await self.circuit_breaker.record_failure(retryable=e.retryable)
        self._error_count += 1
        logger.error(
            f"[{request_id}] {self.PROVIDER.value} failed: "
            f"{e.error_type} - {e.message}"
        )
        raise
    except Exception as e:
        await self.circuit_breaker.record_failure(retryable=False)
        self._error_count += 1
        logger.exception(
            f"[{request_id}] {self.PROVIDER.value} unexpected error: {e}"
        )
        # Chain the cause so the original traceback is preserved.
        raise LLMProviderError(
            message=f"Unexpected error: {str(e)}",
            error_type="unexpected",
            provider=self.PROVIDER,
            model=request.model,
            retryable=False,
        ) from e
async def generate_stream(
    self, request: LLMRequest
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Generate text with streaming support.

    Yields partial responses as chunks arrive, then one final chunk with
    completion metadata. Success and metrics are recorded before the final
    chunk is yielded, so they are not lost when the consumer stops iterating
    as soon as it sees ``finished``.

    Parameters
    ----------
    request : LLMRequest
        Validated generation request.

    Yields
    ------
    Dict[str, Any]
        Streaming chunks with keys:
        - text: the new chunk ("" for the final chunk)
        - accumulated: all text received so far
        - finished: bool indicating end of stream
        - metadata: empty while streaming; on the final chunk it holds
          input_tokens, output_tokens, latency_ms, cost_usd, request_id

    Raises
    ------
    LLMProviderError
        On an open circuit, rate-limit timeout or provider failure.
        Unexpected exceptions are wrapped with ``error_type="unexpected"``
        and chained to their cause.
    """
    # Random suffix, as in generate(), so two streams started in the same
    # millisecond do not share a request id.
    request_id = (
        f"{self.PROVIDER.value}-stream-{int(time.time() * 1000)}"
        f"-{random.randint(1000, 9999)}"
    )
    start_time = time.time()
    logger.info(
        f"[{request_id}] Starting {self.PROVIDER.value} streaming request: "
        f"model={request.model}"
    )
    try:
        if not await self.circuit_breaker.can_execute():
            raise LLMProviderError(
                message="Circuit breaker is OPEN; provider unavailable",
                error_type="circuit_open",
                provider=self.PROVIDER,
                model=request.model,
                retryable=True,
            )
        # Slot is released in the finally below, also on early close.
        await self._acquire_concurrency_slot()
        try:
            input_tokens = self._count_tokens(request.prompt)
            await self._apply_rate_limit(tokens=input_tokens / 1000)
            accumulated_text = ""
            async for chunk in self._generate_stream_internal(request):
                accumulated_text += chunk
                yield {
                    "text": chunk,
                    "accumulated": accumulated_text,
                    "finished": False,
                    "metadata": {},
                }
            # Completion metadata; prices are per 1000 tokens.
            latency_ms = (time.time() - start_time) * 1000
            output_tokens = self._count_tokens(accumulated_text)
            pricing_key = self._get_pricing_key(request.model)
            pricing = PROVIDER_PRICING.get(
                pricing_key,
                PROVIDER_PRICING.get("custom/*", {"input": 0.0, "output": 0.0})
            )
            cost_usd = (
                input_tokens * pricing.get("input", 0.0) +
                output_tokens * pricing.get("output", 0.0)
            ) / 1000
            # Bookkeeping happens before the final yield (see docstring).
            await self.circuit_breaker.record_success()
            self._request_count += 1
            self._total_latency_ms += latency_ms
            logger.info(
                f"[{request_id}] {self.PROVIDER.value} streaming complete: "
                f"latency={latency_ms:.0f}ms, tokens_out={output_tokens}"
            )
            yield {
                "text": "",
                "accumulated": accumulated_text,
                "finished": True,
                "metadata": {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "latency_ms": latency_ms,
                    "cost_usd": cost_usd,
                    "request_id": request_id,
                },
            }
        finally:
            self._release_concurrency_slot()
    except LLMProviderError as e:
        await self.circuit_breaker.record_failure(retryable=e.retryable)
        self._error_count += 1
        logger.error(
            f"[{request_id}] {self.PROVIDER.value} streaming failed: "
            f"{e.error_type} - {e.message}"
        )
        raise
    except Exception as e:
        await self.circuit_breaker.record_failure(retryable=False)
        self._error_count += 1
        logger.exception(
            f"[{request_id}] {self.PROVIDER.value} streaming unexpected error: {e}"
        )
        # Chain the cause so the original traceback is preserved.
        raise LLMProviderError(
            message=f"Unexpected error: {str(e)}",
            error_type="unexpected",
            provider=self.PROVIDER,
            model=request.model,
            retryable=False,
        ) from e
def get_metrics(self) -> Dict[str, Any]:
    """Return a snapshot of this client's counters for monitoring.

    Returns:
        A dict holding the provider name, the request and error counters,
        the derived error rate and average latency (both ``0.0`` while the
        request counter is zero), the circuit-breaker stats, the
        rate-limiter stats (``None`` when no limiter is configured) and the
        concurrency limit together with the semaphore's current value.
    """
    counted = self._request_count
    has_requests = counted > 0

    # Both ratios divide by the request counter, so guard the zero case.
    mean_latency_ms = self._total_latency_ms / counted if has_requests else 0.0
    failure_ratio = self._error_count / counted if has_requests else 0.0

    snapshot: Dict[str, Any] = {
        "provider": self.PROVIDER.value,
        "request_count": self._request_count,
        "error_count": self._error_count,
        "error_rate": failure_ratio,
        "avg_latency_ms": mean_latency_ms,
    }
    snapshot["circuit_breaker"] = self.circuit_breaker.get_stats()

    limiter = self._rate_limiter
    snapshot["rate_limiter"] = limiter.get_stats() if limiter else None

    snapshot["concurrency"] = {
        "max_concurrent": self.max_concurrent_requests,
        # Reads the semaphore's private counter, as there is no public getter.
        "available_slots": self._concurrency_semaphore._value,
    }
    return snapshot
async def health_check(self, timeout: float = 20.0) -> Dict[str, Any]:
    """Probe the provider with a minimal request and report its status.

    The probe sends the prompt ``"ping"`` with ``max_tokens=5`` to
    ``DEFAULT_MODEL`` by calling ``_generate_internal`` directly.

    Args:
        timeout: Seconds to wait for the probe before reporting a timeout.
            Defaults to 20 seconds to reduce false negatives under load.

    Returns:
        A dict whose ``status`` is ``"healthy"`` (with ``latency_ms``,
        ``model`` and ``timestamp``), ``"timeout"`` or ``"unhealthy"``
        (with ``error`` and ``error_type``). Failures are reported in the
        returned dict rather than raised.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    try:
        probe = LLMRequest(
            prompt="ping",
            provider=self.PROVIDER,
            model=self.DEFAULT_MODEL,
            max_tokens=5,  # Minimal tokens for a faster response
            temperature=0.0,
        )
        started = time.time()
        # Only success and latency matter; the response body is discarded.
        await asyncio.wait_for(self._generate_internal(probe), timeout=timeout)
        latency_ms = (time.time() - started) * 1000
        return {
            "status": "healthy",
            "latency_ms": latency_ms,
            "provider": self.PROVIDER.value,
            "model": self.DEFAULT_MODEL,
            # Naive UTC timestamp, same format as the deprecated
            # datetime.utcnow() produced.
            "timestamp": (
                datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            ),
        }
    except asyncio.TimeoutError:
        return {
            "status": "timeout",
            "provider": self.PROVIDER.value,
            "error": f"Health check timed out ({timeout:g}s limit)",
        }
    except Exception as e:
        return {
            "status": "unhealthy",
            "provider": self.PROVIDER.value,
            "error": str(e),
            "error_type": type(e).__name__,
        }
async def close(self) -> None:
    """Close the HTTP client, if one exists, and drop the reference to it."""
    http_client = self._http_client
    if not http_client:
        # Nothing was ever opened (or it was already closed).
        return
    await http_client.aclose()
    self._http_client = None
| # ============================================================================= | |
| # PROVIDER IMPLEMENTATIONS | |
| # ============================================================================= | |
class GeminiClient(BaseLLMClient):
    """Client for Google Gemini, backed by the ``google-genai`` SDK."""

    PROVIDER = LLMProvider.GEMINI
    DEFAULT_MODEL = "gemini-pro"

    def __init__(self, api_key: str, **kwargs):
        """Create the SDK client and initialise token counting.

        Args:
            api_key: API key handed to ``genai.Client``.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``; an optional
                ``model`` entry is extracted and passed on explicitly.

        Raises:
            ImportError: If the ``google-genai`` package is not installed.
        """
        if not _GEMINI_AVAILABLE:
            raise ImportError(
                "google-genai not installed. Install with: pip install google-genai"
            )
        model = kwargs.pop("model", None)
        super().__init__(api_key, model=model, **kwargs)
        self._client = genai.Client(api_key=api_key)
        # ``model`` was popped from kwargs above, so kwargs.get("model") could
        # never see it; use the popped value so a caller-supplied model is
        # honoured by the token counter.
        self._init_token_counter(model or self.DEFAULT_MODEL)

    def _get_pricing_key(self, model: str) -> str:
        """Return the ``PROVIDER_PRICING`` key for *model*."""
        return f"google/{model}"

    @staticmethod
    def _build_config(request: LLMRequest):
        """Translate request sampling options into a Gemini generation config.

        Shared by the streaming and non-streaming paths so both honour the
        same options, including stop sequences.
        """
        return types.GenerateContentConfig(
            max_output_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop_sequences=request.stop_sequences or [],
        )

    async def _generate_internal(self, request: LLMRequest) -> LLMResponse:
        """Run one non-streaming generation call.

        Returns:
            An ``LLMResponse`` carrying the generated text. Token counts,
            latency and cost are left at zero here; usage figures reported
            by the SDK, when present, are placed in ``metadata["usage"]``.

        Raises:
            LLMProviderError: If the SDK call raises (the original exception
                is chained as the cause) or the response text is empty.
        """
        try:
            response = await self._client.aio.models.generate_content(
                model=request.model,
                contents=request.prompt,
                config=self._build_config(request),
            )
        except Exception as e:
            # NOTE(review): every SDK failure is marked retryable, including
            # ones (e.g. bad credentials) that a retry cannot fix -- confirm.
            raise LLMProviderError(
                message=str(e),
                error_type=type(e).__name__.lower(),
                provider=self.PROVIDER,
                model=request.model,
                retryable=True,
            ) from e

        if not response.text:
            raise LLMProviderError(
                message="Gemini returned empty response",
                error_type="empty_response",
                provider=self.PROVIDER,
                model=request.model,
                retryable=True,
            )

        # The attribute can exist but be None, so test the value itself.
        metadata = {}
        usage = getattr(response, "usage_metadata", None)
        if usage is not None:
            metadata["usage"] = {
                "prompt_tokens": usage.prompt_token_count,
                "candidates_tokens": usage.candidates_token_count,
                "total_tokens": usage.total_token_count,
            }

        return LLMResponse(
            text=response.text,
            provider=self.PROVIDER,
            model=request.model,
            input_tokens=0,
            output_tokens=0,
            latency_ms=0,
            cost_usd=0,
            metadata=metadata,
        )

    async def _generate_stream_internal(
        self, request: LLMRequest
    ) -> AsyncGenerator[str, None]:
        """Yield non-empty text fragments from a streaming generation call."""
        stream = await self._client.aio.models.generate_content_stream(
            model=request.model,
            contents=request.prompt,
            config=self._build_config(request),
        )
        async for chunk in stream:
            if chunk.text:
                yield chunk.text

    def _count_tokens(self, text: str) -> int:
        """Count tokens in *text*.

        Uses the configured token counter when available, then tiktoken's
        ``cl100k_base`` encoding, then a four-characters-per-token estimate.
        """
        if self._token_counter:
            return self._token_counter.count_tokens(text)
        if _TIKTOKEN_AVAILABLE:
            enc = tiktoken.get_encoding("cl100k_base")
            return len(enc.encode(text))
        return len(text) // 4
class ClaudeClient(BaseLLMClient):
    """Client for Anthropic Claude API."""

    PROVIDER = LLMProvider.CLAUDE
    DEFAULT_MODEL = "claude-3-sonnet-20240229"

    def __init__(self, api_key: str, **kwargs):
        """Initialise the client.

        Args:
            api_key: Anthropic API key, sent as the ``x-api-key`` header.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``; an optional
                ``model`` entry is extracted and passed on explicitly.
        """
        model = kwargs.pop("model", None)
        super().__init__(api_key, model=model, **kwargs)
        self.base_url = self.base_url or "https://api.anthropic.com"
        # ``model`` was popped from kwargs above, so kwargs.get("model") could
        # never see it; use the popped value instead.
        self._init_token_counter(model or self.DEFAULT_MODEL)

    def _get_pricing_key(self, model: str) -> str:
        """Return the ``PROVIDER_PRICING`` key for *model*."""
        return f"anthropic/{model}"

    def _request_headers(self) -> Dict[str, str]:
        """Return the headers sent with every Messages API call."""
        return {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        }

    @staticmethod
    def _build_payload(request: LLMRequest, stream: bool = False) -> Dict[str, Any]:
        """Build the ``/v1/messages`` request body.

        Shared by both generation paths so streaming honours ``top_k`` and
        ``stop_sequences`` exactly like the non-streaming call.
        """
        payload: Dict[str, Any] = {
            "model": request.model,
            "messages": [{"role": "user", "content": request.prompt}],
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
        }
        # Omit unset optional fields instead of sending JSON null.
        if request.top_k is not None:
            payload["top_k"] = request.top_k
        if request.stop_sequences:
            payload["stop_sequences"] = request.stop_sequences
        if stream:
            payload["stream"] = True
        return payload

    def _empty_response_error(self, request: LLMRequest) -> LLMProviderError:
        """Build the error raised when the reply carries no usable text."""
        return LLMProviderError(
            message="Claude returned empty response",
            error_type="empty_response",
            provider=self.PROVIDER,
            model=request.model,
            retryable=True,
        )

    async def _generate_internal(self, request: LLMRequest) -> LLMResponse:
        """Run one non-streaming Messages API call.

        Returns:
            An ``LLMResponse`` whose text is the concatenation of all text
            blocks in the reply. Token counts, latency and cost are left at
            zero here; the API's usage data is passed through in metadata.

        Raises:
            LLMProviderError: If the reply has no content or no text block.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        self._http_client = self._ensure_http_client(
            self.base_url, self._request_headers()
        )
        response = await self._http_client.post(
            "/v1/messages", json=self._build_payload(request)
        )
        response.raise_for_status()
        data = response.json()

        blocks = data.get("content")
        if not blocks:
            raise self._empty_response_error(request)

        # The first block is not guaranteed to be a text block, so collect
        # every block that actually carries text instead of indexing [0].
        text_parts = [
            block["text"]
            for block in blocks
            if isinstance(block, dict) and isinstance(block.get("text"), str)
        ]
        if not text_parts:
            raise self._empty_response_error(request)

        return LLMResponse(
            text="".join(text_parts),
            provider=self.PROVIDER,
            model=request.model,
            input_tokens=0,
            output_tokens=0,
            latency_ms=0,
            cost_usd=0,
            metadata={
                "stop_reason": data.get("stop_reason"),
                "model": data.get("model"),
                "usage": data.get("usage", {}),
            },
        )

    async def _generate_stream_internal(
        self, request: LLMRequest
    ) -> AsyncGenerator[str, None]:
        """Yield text fragments from a streaming Messages API call.

        Only ``content_block_delta`` events that carry text are surfaced;
        lines that are not valid JSON are skipped.
        """
        self._http_client = self._ensure_http_client(
            self.base_url, self._request_headers()
        )
        payload = self._build_payload(request, stream=True)
        async with self._http_client.stream(
            "POST", "/v1/messages", json=payload
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[6:].strip()
                if data == "[DONE]":
                    break
                try:
                    event = json.loads(data)
                except json.JSONDecodeError:
                    continue
                if event.get("type") != "content_block_delta":
                    continue
                text = (event.get("delta") or {}).get("text", "")
                if text:
                    yield text

    def _count_tokens(self, text: str) -> int:
        """Count tokens in *text*.

        Uses the configured token counter when available, then tiktoken's
        ``cl100k_base`` encoding, then a four-characters-per-token estimate.
        """
        if self._token_counter:
            return self._token_counter.count_tokens(text)
        # Fallback to tiktoken
        if _TIKTOKEN_AVAILABLE:
            enc = tiktoken.get_encoding("cl100k_base")
            return len(enc.encode(text))
        return len(text) // 4
class OpenAIClient(BaseLLMClient):
    """Client for OpenAI API."""

    PROVIDER = LLMProvider.OPENAI
    DEFAULT_MODEL = "gpt-4"

    def __init__(self, api_key: str, **kwargs):
        """Initialise the client.

        Args:
            api_key: OpenAI API key, sent as a bearer token.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``; an optional
                ``model`` entry is extracted and passed on explicitly.
        """
        model = kwargs.pop("model", None)
        super().__init__(api_key, model=model, **kwargs)
        self.base_url = self.base_url or "https://api.openai.com/v1"
        # ``model`` was popped from kwargs above, so kwargs.get("model") could
        # never see it; use the popped value instead.
        self._init_token_counter(model or self.DEFAULT_MODEL)

    def _get_pricing_key(self, model: str) -> str:
        """Return the ``PROVIDER_PRICING`` key for *model*."""
        return f"openai/{model}"

    def _request_headers(self) -> Dict[str, str]:
        """Return the headers sent with every chat-completions call."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "content-type": "application/json",
        }

    @staticmethod
    def _build_payload(request: LLMRequest, stream: bool = False) -> Dict[str, Any]:
        """Build the ``/chat/completions`` request body.

        Shared by both generation paths so streaming honours stop sequences
        exactly like the non-streaming call.
        """
        payload: Dict[str, Any] = {
            "model": request.model,
            "messages": [{"role": "user", "content": request.prompt}],
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
        }
        if request.stop_sequences:
            payload["stop"] = request.stop_sequences
        if stream:
            payload["stream"] = True
        return payload

    @staticmethod
    def _delta_text(chunk: Dict[str, Any]) -> str:
        """Extract the text fragment from one streaming chunk.

        Chunks may carry an empty ``choices`` list or a null delta/content,
        so every level is defaulted instead of indexed blindly.
        """
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    async def _generate_internal(self, request: LLMRequest) -> LLMResponse:
        """Run one non-streaming chat-completions call.

        Returns:
            An ``LLMResponse`` with the first choice's message content.
            Token counts, latency and cost are left at zero here; the API's
            usage data is passed through in metadata.

        Raises:
            LLMProviderError: If the reply has no choice, no message, or a
                null message content.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        self._http_client = self._ensure_http_client(
            self.base_url, self._request_headers()
        )
        response = await self._http_client.post(
            "/chat/completions", json=self._build_payload(request)
        )
        response.raise_for_status()
        data = response.json()

        choices = data.get("choices")
        message = choices[0].get("message") if choices else None
        text = message.get("content") if message else None
        # A null content would otherwise be passed on as text=None.
        if text is None:
            raise LLMProviderError(
                message="OpenAI returned empty response",
                error_type="empty_response",
                provider=self.PROVIDER,
                model=request.model,
                retryable=True,
            )

        return LLMResponse(
            text=text,
            provider=self.PROVIDER,
            model=request.model,
            input_tokens=0,
            output_tokens=0,
            latency_ms=0,
            cost_usd=0,
            metadata={
                "finish_reason": choices[0].get("finish_reason"),
                "usage": data.get("usage", {}),
                "system_fingerprint": data.get("system_fingerprint"),
            },
        )

    async def _generate_stream_internal(
        self, request: LLMRequest
    ) -> AsyncGenerator[str, None]:
        """Yield text fragments from a streaming chat-completions call.

        Stops at the ``[DONE]`` sentinel; lines that are not valid JSON are
        skipped.
        """
        self._http_client = self._ensure_http_client(
            self.base_url, self._request_headers()
        )
        payload = self._build_payload(request, stream=True)
        async with self._http_client.stream(
            "POST", "/chat/completions", json=payload
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[6:].strip()
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue
                content = self._delta_text(chunk)
                if content:
                    yield content

    def _count_tokens(self, text: str) -> int:
        """Count tokens in *text*.

        Uses the configured token counter when available, then tiktoken
        (the ``gpt-4`` encoding, or ``cl100k_base`` if that lookup fails),
        then a four-characters-per-token estimate.
        """
        if self._token_counter:
            return self._token_counter.count_tokens(text)
        # Fallback
        if _TIKTOKEN_AVAILABLE:
            try:
                enc = tiktoken.encoding_for_model("gpt-4")
            except KeyError:
                enc = tiktoken.get_encoding("cl100k_base")
            return len(enc.encode(text))
        return len(text) // 4
class GenericAPIClient(BaseLLMClient):
    """
    Generic client for providers with OpenAI-compatible API format.

    Supports DeepSeek, Qwen, Kimi, OpenRouter, and custom endpoints.
    """

    PROVIDER = LLMProvider.CUSTOM
    DEFAULT_MODEL = ""

    def __init__(
        self,
        provider: LLMProvider,
        api_key: str,
        base_url: str,
        default_model: str,
        **kwargs,
    ):
        """Initialise a client for one OpenAI-compatible endpoint.

        Args:
            provider: Provider this instance represents; overrides the
                class-level ``PROVIDER`` on the instance.
            api_key: API key, sent as a bearer token.
            base_url: Base URL that ``/chat/completions`` is appended to.
            default_model: Model name stored as this instance's
                ``DEFAULT_MODEL`` and used to set up token counting.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``.
        """
        # Set before super().__init__ so the base initialiser sees them.
        self.PROVIDER = provider
        self.DEFAULT_MODEL = default_model
        super().__init__(api_key, base_url=base_url, **kwargs)
        self._init_token_counter(default_model)

    def _get_pricing_key(self, model: str) -> str:
        """Return the ``PROVIDER_PRICING`` key for *model*."""
        prefix = {
            LLMProvider.DEEPSEEK: "deepseek",
            LLMProvider.QWEN: "qwen",
            LLMProvider.KIMI: "moonshot",
            LLMProvider.OPENROUTER: "openrouter",
        }.get(self.PROVIDER, "custom")
        return f"{prefix}/{model}"

    def _request_headers(self) -> Dict[str, str]:
        """Return the headers sent with every chat-completions call."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "content-type": "application/json",
        }

    @staticmethod
    def _build_payload(request: LLMRequest, stream: bool = False) -> Dict[str, Any]:
        """Build the ``/chat/completions`` request body.

        Shared by both generation paths so streaming honours stop sequences
        exactly like the non-streaming call.
        """
        payload: Dict[str, Any] = {
            "model": request.model,
            "messages": [{"role": "user", "content": request.prompt}],
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
        }
        if request.stop_sequences:
            payload["stop"] = request.stop_sequences
        if stream:
            payload["stream"] = True
        return payload

    @staticmethod
    def _delta_text(chunk: Dict[str, Any]) -> str:
        """Extract the text fragment from one streaming chunk.

        Chunks may carry an empty ``choices`` list or a null delta/content,
        so every level is defaulted instead of indexed blindly.
        """
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    async def _generate_internal(self, request: LLMRequest) -> LLMResponse:
        """Run one non-streaming chat-completions call.

        Returns:
            An ``LLMResponse`` with the first choice's message content.
            Token counts, latency and cost are left at zero here; the API's
            usage data is passed through in metadata.

        Raises:
            LLMProviderError: If the reply has no choice, no message, or a
                null message content.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        self._http_client = self._ensure_http_client(
            self.base_url, self._request_headers()
        )
        response = await self._http_client.post(
            "/chat/completions", json=self._build_payload(request)
        )
        response.raise_for_status()
        data = response.json()

        choices = data.get("choices")
        message = choices[0].get("message") if choices else None
        text = message.get("content") if message else None
        # A null content would otherwise be passed on as text=None.
        if text is None:
            raise LLMProviderError(
                message=f"{self.PROVIDER.value} returned empty response",
                error_type="empty_response",
                provider=self.PROVIDER,
                model=request.model,
                retryable=True,
            )

        return LLMResponse(
            text=text,
            provider=self.PROVIDER,
            model=request.model,
            input_tokens=0,
            output_tokens=0,
            latency_ms=0,
            cost_usd=0,
            metadata={
                "finish_reason": choices[0].get("finish_reason"),
                "usage": data.get("usage", {}),
            },
        )

    async def _generate_stream_internal(
        self, request: LLMRequest
    ) -> AsyncGenerator[str, None]:
        """Yield text fragments from a streaming chat-completions call.

        Stops at the ``[DONE]`` sentinel; lines that are not valid JSON are
        skipped.
        """
        self._http_client = self._ensure_http_client(
            self.base_url, self._request_headers()
        )
        payload = self._build_payload(request, stream=True)
        async with self._http_client.stream(
            "POST", "/chat/completions", json=payload
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[6:].strip()
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue
                content = self._delta_text(chunk)
                if content:
                    yield content
| # ============================================================================= | |
| # FACTORY & ORCHESTRATOR | |
| # ============================================================================= | |
class LLMClientFactory:
    """Factory for creating configured LLM clients."""

    # Cache of created clients, keyed by "<provider value>:<model or 'default'>".
    _clients: Dict[str, BaseLLMClient] = {}

    @classmethod
    def create(
        cls,
        provider: LLMProvider,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        retry_config: Optional[RetryConfig] = None,
        timeout_config: Optional[TimeoutConfig] = None,
        rate_limit_config: Optional[RateLimitConfig] = None,
        circuit_breaker: Optional[CircuitBreaker] = None,
        max_concurrent_requests: int = 10,
    ) -> BaseLLMClient:
        """Create a configured client for the specified provider.

        Args:
            provider: Provider to build a client for.
            model: Model name; falls back to a per-provider default.
            api_key: API key; when ``None`` it is read from the provider's
                environment variable.
            base_url: Endpoint override for OpenAI-compatible providers;
                mandatory for ``LLMProvider.CUSTOM``.
            retry_config: Retry settings; a default ``RetryConfig`` if omitted.
            timeout_config: Timeout settings; a default ``TimeoutConfig`` if
                omitted.
            rate_limit_config: Optional rate-limit settings, passed through.
            circuit_breaker: Optional circuit breaker, passed through.
            max_concurrent_requests: Concurrency limit passed to the client.

        Returns:
            A new client instance (not cached; see ``get_or_create``).

        Raises:
            ValueError: If no API key can be resolved, if ``base_url`` is
                missing for a custom provider, or if the provider is
                unsupported.
        """
        # Defined up front because the error message below needs it even
        # when the caller passed an (empty) api_key explicitly.
        env_key_map = {
            LLMProvider.GEMINI: "GEMINI_API_KEY",
            LLMProvider.CLAUDE: "ANTHROPIC_API_KEY",
            LLMProvider.OPENAI: "OPENAI_API_KEY",
            LLMProvider.DEEPSEEK: "DEEPSEEK_API_KEY",
            LLMProvider.QWEN: "QWEN_API_KEY",
            LLMProvider.KIMI: "KIMI_API_KEY",
            LLMProvider.OPENROUTER: "OPENROUTER_API_KEY",
        }
        if api_key is None:
            env_var = env_key_map.get(provider)
            if env_var:
                api_key = os.getenv(env_var)
        if not api_key:
            raise ValueError(
                f"API key required for {provider.value}. "
                f"Set {env_key_map.get(provider, 'API_KEY')} env var or pass api_key parameter."
            )

        default_model_map = {
            LLMProvider.GEMINI: "gemini-pro",
            LLMProvider.CLAUDE: "claude-3-sonnet-20240229",
            LLMProvider.OPENAI: "gpt-4",
            LLMProvider.DEEPSEEK: "deepseek-chat",
            LLMProvider.QWEN: "qwen-max",
            LLMProvider.KIMI: "kimi",
            LLMProvider.OPENROUTER: "openai/gpt-4",
        }
        model = model or default_model_map.get(provider, "")

        config_kwargs = {
            "retry_config": retry_config or RetryConfig(),
            "timeout_config": timeout_config or TimeoutConfig(),
            "rate_limit_config": rate_limit_config,
            "circuit_breaker": circuit_breaker,
            "max_concurrent_requests": max_concurrent_requests,
        }

        if provider == LLMProvider.GEMINI:
            return GeminiClient(api_key=api_key, model=model, **config_kwargs)
        if provider == LLMProvider.CLAUDE:
            return ClaudeClient(api_key=api_key, model=model, **config_kwargs)
        if provider == LLMProvider.OPENAI:
            return OpenAIClient(api_key=api_key, model=model, **config_kwargs)

        # Providers that speak the OpenAI-compatible chat-completions format.
        base_urls = {
            LLMProvider.DEEPSEEK: "https://api.deepseek.com/v1",
            LLMProvider.QWEN: "https://dashscope.aliyuncs.com/compatible-mode/v1",
            LLMProvider.KIMI: "https://api.moonshot.cn/v1",
            LLMProvider.OPENROUTER: "https://openrouter.ai/api/v1",
        }
        if provider in base_urls:
            return GenericAPIClient(
                provider=provider,
                api_key=api_key,
                base_url=base_url or base_urls[provider],
                default_model=model,
                **config_kwargs,
            )
        if provider == LLMProvider.CUSTOM:
            if not base_url:
                raise ValueError("base_url required for CUSTOM provider")
            return GenericAPIClient(
                provider=provider,
                api_key=api_key,
                base_url=base_url,
                default_model=model,
                **config_kwargs,
            )
        raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def get_or_create(cls, provider: LLMProvider, **kwargs) -> BaseLLMClient:
        """Get existing client or create new one (singleton per provider/model).

        The cache key uses only the provider and the ``model`` keyword; other
        keyword arguments take effect only when the client is first created.
        """
        # An empty or missing model both mean "provider default", so they
        # must share one cache slot instead of creating two clients.
        model_key = kwargs.get("model") or "default"
        key = f"{provider.value}:{model_key}"
        if key not in cls._clients:
            cls._clients[key] = cls.create(provider, **kwargs)
        return cls._clients[key]

    @classmethod
    async def close_all(cls) -> None:
        """Close all managed clients and empty the cache.

        Every client is closed even if an earlier one fails; the first
        failure is re-raised once all clients have been attempted.
        """
        clients = list(cls._clients.values())
        cls._clients.clear()
        first_error: Optional[Exception] = None
        for client in clients:
            try:
                await client.close()
            except Exception as exc:
                logger.exception("Failed to close LLM client %r", client)
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error
class LLMOrchestrator:
    """High-level orchestrator for multi-provider LLM inference."""

    def __init__(
        self,
        default_provider: LLMProvider = LLMProvider.OPENAI,
        fallback_providers: Optional[List[LLMProvider]] = None,
        load_balance: bool = False,
        max_concurrent_requests: int = 10,
        default_model: Optional[str] = None,
    ):
        """Configure provider selection.

        Args:
            default_provider: Provider used when a call names none.
            fallback_providers: Providers tried, in order, after a retryable
                failure of the selected provider.
            load_balance: When true (and fallbacks exist), ``generate``
                rotates round-robin over the default and fallback providers.
            max_concurrent_requests: Concurrency limit passed to the factory
                when a client is obtained.
            default_model: Model used when a call names none.
        """
        self.default_provider = default_provider
        self.fallback_providers = fallback_providers or []
        self.load_balance = load_balance
        self.max_concurrent_requests = max_concurrent_requests
        self._provider_index = 0  # round-robin cursor for load balancing
        self.default_model = default_model or ""

    async def generate(
        self,
        prompt: str,
        provider: Optional[LLMProvider] = None,
        model: Optional[str] = None,
        **request_kwargs,
    ) -> LLMResponse:
        """Generate text with automatic provider selection and fallback.

        Args:
            prompt: Prompt text.
            provider: Provider to use; auto-selected when ``None``.
            model: Model name; the orchestrator default is used when omitted.
            **request_kwargs: Extra fields forwarded to ``LLMRequest``.

        Returns:
            The first successful response. Responses served by a fallback
            carry ``metadata["fallback_from"]`` naming the failed provider.

        Raises:
            LLMProviderError: Re-raised as-is when the selected provider
                fails with a non-retryable error; raised with
                ``error_type="all_providers_failed"`` (per-provider error
                types in ``details["failures"]``) when every candidate fails.
        """
        auto_selected = provider is None
        if auto_selected:
            if self.load_balance and self.fallback_providers:
                pool = [self.default_provider] + self.fallback_providers
                provider = pool[self._provider_index % len(pool)]
                self._provider_index += 1
            else:
                provider = self.default_provider

        # Use orchestrator default model when none provided
        effective_model = model or self.default_model or ""
        request = LLMRequest(
            prompt=prompt,
            provider=provider,
            model=effective_model,
            **request_kwargs,
        )

        try:
            client = LLMClientFactory.get_or_create(
                provider,
                model=request.model,
                max_concurrent_requests=self.max_concurrent_requests,
            )
            return await client.generate(request)
        except LLMProviderError as primary_error:
            logger.warning(
                f"Primary provider {provider.value} failed: "
                f"{primary_error.error_type} - {primary_error.message}"
            )
            if not primary_error.retryable:
                raise

            failures = {provider.value: primary_error.error_type}

            # When load balancing picked a fallback provider as primary, the
            # default provider is still untried and must join the chain.
            # Explicitly requested providers keep the fallback-only chain.
            if auto_selected:
                chain = [self.default_provider, *self.fallback_providers]
            else:
                chain = self.fallback_providers

            attempted = {provider}
            for fallback in chain:
                if fallback in attempted:
                    continue
                attempted.add(fallback)
                try:
                    logger.info(f"Trying fallback provider: {fallback.value}")
                    # NOTE(review): the primary's model name and request are
                    # reused for the fallback provider -- confirm this is
                    # intended when providers do not share model names.
                    client = LLMClientFactory.get_or_create(
                        fallback,
                        model=request.model,
                        max_concurrent_requests=self.max_concurrent_requests,
                    )
                    response = await client.generate(request)
                    response.metadata["fallback_from"] = provider.value
                    return response
                except LLMProviderError as fallback_error:
                    logger.warning(
                        f"Fallback provider {fallback.value} failed: "
                        f"{fallback_error.error_type}"
                    )
                    failures[fallback.value] = fallback_error.error_type
                    continue

            raise LLMProviderError(
                message="All providers failed to generate response",
                error_type="all_providers_failed",
                provider=provider,
                model=request.model,
                retryable=False,
                details={"failures": failures},
            )

    async def generate_stream(
        self,
        prompt: str,
        provider: Optional[LLMProvider] = None,
        model: Optional[str] = None,
        **request_kwargs,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream generated chunks from one provider.

        Uses the given provider, or the default provider when none is given.
        No load balancing or fallback is applied here; chunks from the
        client's ``generate_stream`` are yielded unchanged.
        """
        if provider is None:
            provider = self.default_provider
        effective_model = model or self.default_model or ""
        request = LLMRequest(
            prompt=prompt,
            provider=provider,
            model=effective_model,
            **request_kwargs,
        )
        client = LLMClientFactory.get_or_create(
            provider,
            model=request.model,
            max_concurrent_requests=self.max_concurrent_requests,
        )
        async for chunk in client.generate_stream(request):
            yield chunk

    async def _check_provider(self, provider: LLMProvider) -> Dict[str, Any]:
        """Health-check one provider, reporting any failure as a dict."""
        try:
            client = LLMClientFactory.get_or_create(provider)
            return await client.health_check()
        except Exception as e:
            return {
                "status": "error",
                "error": str(e),
            }

    async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
        """Perform health checks on all configured providers.

        Providers are checked concurrently and each distinct provider is
        checked once, so the total wait is bounded by the slowest provider
        rather than the sum of all of them.

        Returns:
            A dict mapping each provider's value to its health report.
        """
        # dict.fromkeys de-duplicates while preserving configuration order.
        providers = list(
            dict.fromkeys([self.default_provider, *self.fallback_providers])
        )
        reports = await asyncio.gather(
            *(self._check_provider(p) for p in providers)
        )
        return {p.value: report for p, report in zip(providers, reports)}

    def get_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Get metrics from every client currently cached by the factory."""
        return {
            key: client.get_metrics()
            for key, client in LLMClientFactory._clients.items()
        }

    async def close(self) -> None:
        """Clean up all resources by closing every factory-managed client."""
        await LLMClientFactory.close_all()
| # ============================================================================= | |
| # EXCEPTIONS | |
| # ============================================================================= | |
class LLMProviderError(Exception):
    """Base exception for LLM provider errors."""

    def __init__(
        self,
        message: str,
        error_type: str,
        provider: "LLMProvider",
        model: Optional[str] = None,
        retryable: bool = False,
        details: Optional[Dict[str, Any]] = None,
    ):
        """Record what failed, where, and whether a retry may help.

        Args:
            message: Human-readable description; also the exception's sole
                positional ``args`` entry.
            error_type: Short machine-readable category.
            provider: Provider the failure is attributed to.
            model: Model involved, if known.
            retryable: Whether the caller may retry or fall back.
            details: Optional extra context; stored as ``{}`` when omitted.
        """
        # Local import: the module header only imports ``datetime`` itself.
        from datetime import timezone

        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.provider = provider
        self.model = model
        self.retryable = retryable
        self.details = details or {}
        # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
        self.timestamp = datetime.now(timezone.utc).replace(tzinfo=None)

    def __reduce__(self):
        """Support ``copy`` and ``pickle``.

        The default exception reduction rebuilds from ``args`` alone, which
        holds only ``message`` and so cannot satisfy ``__init__``. Supply the
        full argument list and restore the original timestamp as state.
        """
        init_args = (
            self.message,
            self.error_type,
            self.provider,
            self.model,
            self.retryable,
            self.details,
        )
        return (self.__class__, init_args, {"timestamp": self.timestamp})

    def _provider_label(self) -> Any:
        """Return ``provider.value``, or the raw provider if it has none."""
        return getattr(self.provider, "value", self.provider)

    def __str__(self) -> str:
        return (
            f"[{self._provider_label()}] {self.error_type}: {self.message}"
            f"{' (retryable)' if self.retryable else ''}"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for logging/serialization."""
        return {
            "message": self.message,
            "error_type": self.error_type,
            "provider": self._provider_label(),
            "model": self.model,
            "retryable": self.retryable,
            "details": self.details,
            "timestamp": self.timestamp.isoformat(),
        }
| # ============================================================================= | |
| # UTILITIES | |
| # ============================================================================= | |
def validate_prompt(prompt: str, max_length: int = 100000) -> str:
    """Check a prompt and return it with surrounding whitespace removed.

    Args:
        prompt: Candidate prompt text.
        max_length: Largest allowed length, in characters, of the stripped
            prompt.

    Returns:
        The stripped prompt.

    Raises:
        ValueError: If the prompt is falsy or not a string, is blank after
            stripping, or is longer than ``max_length``.
    """
    if not (prompt and isinstance(prompt, str)):
        raise ValueError("prompt must be a non-empty string")

    cleaned = prompt.strip()
    if not cleaned:
        raise ValueError("prompt cannot be empty after stripping")

    size = len(cleaned)
    if size > max_length:
        raise ValueError(
            f"prompt exceeds maximum length of {max_length} characters "
            f"(got {size})"
        )
    return cleaned
def _pricing_rate(entry: Any, key: str, attr: str) -> float:
    """Read one rate from a mapping-style or attribute-style pricing entry."""
    if hasattr(entry, "get"):
        return entry.get(key, 0.0)
    return getattr(entry, attr, 0.0)


def estimate_cost(
    input_tokens: int,
    output_tokens: int,
    provider: LLMProvider,
    model: str,
) -> float:
    """Estimate generation cost without making a request.

    Pricing is looked up in ``PROVIDER_PRICING`` using the same key prefixes
    the provider clients use (e.g. ``google/`` for Gemini), then the legacy
    ``<provider value>/`` prefix, then wildcard entries, and finally a
    zero-cost default.

    Args:
        input_tokens: Number of prompt tokens.
        output_tokens: Number of generated tokens.
        provider: Provider whose pricing applies.
        model: Model name used to select the pricing entry.

    Returns:
        The estimated cost; rates are divided by 1000, i.e. treated as
        per-1000-token prices (presumably USD -- confirm against the
        pricing table).
    """
    # Mirrors the prefixes returned by the clients' _get_pricing_key methods.
    provider_prefixes = {
        LLMProvider.GEMINI: "google",
        LLMProvider.CLAUDE: "anthropic",
        LLMProvider.OPENAI: "openai",
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "qwen",
        LLMProvider.KIMI: "moonshot",
        LLMProvider.OPENROUTER: "openrouter",
    }
    prefix = provider_prefixes.get(provider, "custom")

    candidate_keys = (
        f"{prefix}/{model}",
        f"{provider.value}/{model}",
        f"{prefix}/*",
        f"{provider.value}/*",
        "custom/*",
    )
    pricing = next(
        (PROVIDER_PRICING[key] for key in candidate_keys if key in PROVIDER_PRICING),
        {"input": 0.0, "output": 0.0},
    )

    # Entries are read like the streaming path does (.get("input"/"output")),
    # with an attribute fallback for object-style entries.
    return (
        input_tokens * _pricing_rate(pricing, "input", "input_price")
        + output_tokens * _pricing_rate(pricing, "output", "output_price")
    ) / 1000