""" LLM client orchestrator for multi-provider inference. This module provides a unified, scalable interface for interacting with multiple Large Language Model providers including: - Google Gemini - Anthropic Claude - OpenAI (GPT series) - DeepSeek - Qwen (Alibaba) - Kimi (Moonshot AI) - OpenRouter (aggregator) - Custom HTTP endpoints via generic client Performance Characteristics --------------------------- - Request latency: Provider-dependent (typically 200ms-5s for inference) - Retry overhead: O(log n) for exponential backoff with max_attempts=3-5 - Token counting: O(L) where L = character length via provider-specific tokenizer - Memory: O(1) per request + O(P) for provider configs, P = provider count Thread Safety ------------- - All client methods are async and reentrant - Rate limiters, circuit breakers, and concurrency semaphores use asyncio.Lock - No shared mutable state beyond configured provider instances - Safe for concurrent use in FastAPI/Starlette applications Author: IntelliDeep Labs Team License: BSL 1.1 """ from __future__ import annotations import asyncio import json import logging import os import random import time from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime from enum import Enum, auto from typing import Any, AsyncGenerator, Callable, Dict, List, Optional import httpx from pydantic import field_validator from nlproxy.utils.constants import PROVIDER_PRICING try: from google import genai from google.genai import types _GEMINI_AVAILABLE = True except ImportError: _GEMINI_AVAILABLE = False genai = None types = None try: import tiktoken _TIKTOKEN_AVAILABLE = True except ImportError: _TIKTOKEN_AVAILABLE = False tiktoken = None # type: ignore # Configure logger with project-standard format logger = logging.getLogger(__name__) # ============================================================================= # CONFIGURATION & PRICING # ============================================================================= @dataclass(frozen=True) class ProviderPricing: """ Pricing configuration for an LLM provider. Attributes ---------- input_price : float Cost per 1000 input tokens in USD. output_price : float Cost per 1000 output tokens in USD. currency : str Currency code (default: "USD"). """ input_price: float output_price: float currency: str = "USD" # Provider pricing is imported from shared constants; env vars may override values at import time. @dataclass class RetryConfig: """ Configuration for retry behavior with exponential backoff. Attributes ---------- max_attempts : int Maximum number of retry attempts (default: 3). base_delay : float Initial delay in seconds before first retry (default: 1.0). max_delay : float Maximum delay cap in seconds (default: 30.0). exponential_base : float Base for exponential backoff calculation (default: 2.0). jitter : bool Whether to add random jitter to delays (default: True). retryable_exceptions : Tuple[Type[Exception], ...] Exception types that trigger a retry. """ max_attempts: int = 3 base_delay: float = 1.0 max_delay: float = 30.0 exponential_base: float = 2.0 jitter: bool = True retryable_exceptions: tuple = ( httpx.TimeoutException, httpx.NetworkError, httpx.HTTPStatusError, ConnectionError, asyncio.TimeoutError, ) @dataclass class TimeoutConfig: """ Timeout configuration for LLM requests. Attributes ---------- connect : float Connection timeout in seconds (default: 10.0). read : float Read/response timeout in seconds (default: 60.0). write : float Write/request timeout in seconds (default: 10.0). pool : float Connection pool timeout in seconds (default: 10.0). """ connect: float = 10.0 read: float = 60.0 write: float = 10.0 pool: float = 10.0 @dataclass class RateLimitConfig: """ Rate limiting configuration using token bucket algorithm. Attributes ---------- requests_per_minute : Optional[int] Maximum requests allowed per minute. None = unlimited. tokens_per_request : int Estimated token cost per request for rate calculation. bucket_capacity : int Maximum tokens in the bucket (burst capacity). refill_rate : float Tokens added per second (sustained rate). """ requests_per_minute: Optional[int] = None tokens_per_request: int = 1000 bucket_capacity: int = 10000 refill_rate: float = 100.0 # ============================================================================= # DATA MODELS # ============================================================================= class LLMProvider(str, Enum): """ Supported LLM providers. Values correspond to provider identifiers used in configuration. """ GEMINI = "gemini" CLAUDE = "claude" OPENAI = "openai" DEEPSEEK = "deepseek" QWEN = "qwen" KIMI = "kimi" OPENROUTER = "openrouter" CUSTOM = "custom" class RequestStatus(str, Enum): """Request lifecycle status.""" PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" FAILED = "failed" RETRYING = "retrying" TIMEOUT = "timeout" RATE_LIMITED = "rate_limited" CIRCUIT_OPEN = "circuit_open" @dataclass class LLMRequest: """ Unified request model for all LLM providers. Attributes ---------- prompt : str Input text to send to the LLM (required). provider : LLMProvider Target provider for inference. model : str Model identifier (e.g., "gpt-4", "claude-3-opus"). max_tokens : int Maximum tokens to generate (default: 512). temperature : float Sampling temperature ∈ [0.0, 2.0] (default: 0.7). top_p : float Nucleus sampling threshold ∈ [0.0, 1.0] (default: 0.95). top_k : int Top-k sampling parameter (default: 40). stop_sequences : Optional[List[str]] Sequences that trigger generation stop. metadata : Optional[Dict[str, Any]] Additional metadata for logging/tracing. """ prompt: str provider: LLMProvider model: str max_tokens: int = 512 temperature: float = 0.7 top_p: float = 0.95 top_k: int = 40 stop_sequences: Optional[List[str]] = None metadata: Optional[Dict[str, Any]] = None @field_validator("prompt") @classmethod def prompt_must_not_be_empty(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("prompt must not be empty") return v.strip() @field_validator("temperature") @classmethod def temperature_in_range(cls, v: float) -> float: if not 0.0 <= v <= 2.0: raise ValueError("temperature must be in [0.0, 2.0]") return v @field_validator("top_p") @classmethod def top_p_in_range(cls, v: float) -> float: if not 0.0 <= v <= 1.0: raise ValueError("top_p must be in [0.0, 1.0]") return v @dataclass class LLMResponse: """ Unified response model from all LLM providers. Attributes ---------- text : str Generated text output. provider : LLMProvider Provider that generated the response. model : str Model that generated the response. input_tokens : int Number of tokens in the input prompt. output_tokens : int Number of tokens in the generated output. latency_ms : float End-to-end latency in milliseconds. cost_usd : float Estimated cost in USD based on provider pricing. metadata : Dict[str, Any] Provider-specific metadata (finish_reason, logprobs, etc.). request_id : str Unique identifier for tracing/logging. timestamp : datetime Response generation timestamp. """ text: str provider: LLMProvider model: str input_tokens: int output_tokens: int latency_ms: float cost_usd: float metadata: Dict[str, Any] = field(default_factory=dict) request_id: str = "" timestamp: datetime = field(default_factory=datetime.utcnow) @dataclass class LLMError: """ Structured error information for failed requests. Attributes ---------- message : str Human-readable error description. error_type : str Categorized error type (timeout, rate_limit, auth, etc.). provider : LLMProvider Provider where error occurred. model : Optional[str] Model involved (if applicable). retryable : bool Whether the error is transient and retryable. details : Optional[Dict[str, Any]] Additional error context (status code, response body, etc.). timestamp : datetime Error occurrence timestamp. """ message: str error_type: str provider: LLMProvider model: Optional[str] = None retryable: bool = False details: Optional[Dict[str, Any]] = None timestamp: datetime = field(default_factory=datetime.utcnow) # ============================================================================= # CIRCUIT BREAKER (Improved: distinguishes retryable vs non-retryable) # ============================================================================= class CircuitState(Enum): """Circuit breaker states.""" CLOSED = auto() OPEN = auto() HALF_OPEN = auto() class CircuitBreaker: """ Circuit breaker for fault tolerance in LLM provider calls. IMPROVEMENT: Distinguishes between retryable and non-retryable errors. Only non-retryable errors or persistent failures after retries increment the failure count that can trip the circuit. State Transitions: ------------------ CLOSED → OPEN: When non-retryable failures ≥ threshold within window OPEN → HALF_OPEN: After recovery_timeout expires HALF_OPEN → CLOSED: When success_count ≥ success_threshold HALF_OPEN → OPEN: On any non-retryable failure during testing """ def __init__( self, failure_threshold: int = 5, recovery_timeout: float = 30.0, success_threshold: int = 3, window_seconds: float = 60.0, ): self.failure_threshold = failure_threshold self.recovery_timeout = recovery_timeout self.success_threshold = success_threshold self.window_seconds = window_seconds self._state = CircuitState.CLOSED self._failure_count = 0 self._success_count = 0 self._last_failure_time: Optional[float] = None self._lock = asyncio.Lock() @property def state(self) -> CircuitState: """Current circuit state with automatic OPEN → HALF_OPEN transition.""" if self._state == CircuitState.OPEN: if self._last_failure_time: elapsed = time.time() - self._last_failure_time if elapsed >= self.recovery_timeout: return CircuitState.HALF_OPEN return self._state async def can_execute(self) -> bool: """Check if a request can proceed through the circuit.""" async with self._lock: current_state = self.state if current_state == CircuitState.CLOSED: return True elif current_state == CircuitState.HALF_OPEN: return True else: return False async def record_success(self) -> None: """Record a successful request.""" async with self._lock: if self._state == CircuitState.HALF_OPEN: self._success_count += 1 if self._success_count >= self.success_threshold: self._reset() logger.info("Circuit breaker CLOSED after successful recovery") elif self._state == CircuitState.CLOSED: self._failure_count = 0 async def record_failure(self, retryable: bool = False) -> None: """ Record a failed request. IMPROVEMENT: Only non-retryable errors increment the failure count that can trip the circuit breaker. Retryable errors are logged but don't affect circuit state unless they persist after max retries. Parameters ---------- retryable : bool Whether the error is transient and retryable. """ async with self._lock: now = time.time() # Remove old failures outside the window if self._last_failure_time: if now - self._last_failure_time > self.window_seconds: self._failure_count = 0 # Only increment failure count for non-retryable errors if not retryable: self._failure_count += 1 self._last_failure_time = now if self._state == CircuitState.HALF_OPEN: self._state = CircuitState.OPEN logger.warning("Circuit breaker re-OPENED after non-retryable failure in HALF_OPEN") elif self._state == CircuitState.CLOSED: if self._failure_count >= self.failure_threshold: self._state = CircuitState.OPEN logger.warning( f"Circuit breaker OPENED: {self._failure_count} non-retryable failures " f"in {self.window_seconds}s window" ) else: logger.debug(f"Retryable error recorded; circuit breaker state unchanged: {self._state.name}") def _reset(self) -> None: """Reset circuit breaker to initial closed state.""" self._state = CircuitState.CLOSED self._failure_count = 0 self._success_count = 0 self._last_failure_time = None def get_stats(self) -> Dict[str, Any]: """Return circuit breaker statistics for monitoring.""" return { "state": self.state.name, "failure_count": self._failure_count, "success_count": self._success_count, "last_failure_time": self._last_failure_time, "failure_threshold": self.failure_threshold, "recovery_timeout": self.recovery_timeout, } # ============================================================================= # RATE LIMITER + CONCURRENCY SEMAPHORE # ============================================================================= class TokenBucket: """ Token bucket rate limiter for controlling request throughput. Allows burst traffic up to bucket_capacity while sustaining refill_rate tokens per second long-term. """ def __init__(self, capacity: float, refill_rate: float): self.capacity = capacity self.refill_rate = refill_rate self._tokens = capacity self._last_refill = time.time() self._lock = asyncio.Lock() async def acquire(self, tokens: float = 1.0, timeout: Optional[float] = None) -> bool: """Attempt to acquire tokens from the bucket.""" start_time = time.time() while True: async with self._lock: now = time.time() elapsed = now - self._last_refill self._tokens = min( self.capacity, self._tokens + self.refill_rate * elapsed ) self._last_refill = now if self._tokens >= tokens: self._tokens -= tokens return True needed = tokens - self._tokens wait_time = needed / self.refill_rate if timeout is not None: elapsed_total = time.time() - start_time if elapsed_total + wait_time > timeout: return False await asyncio.sleep(min(wait_time, 0.1)) def get_stats(self) -> Dict[str, float]: """Return current bucket statistics.""" return { "available_tokens": self._tokens, "capacity": self.capacity, "refill_rate": self.refill_rate, } # ============================================================================= # PROVIDER-SPECIFIC TOKENIZERS # ============================================================================= class TokenCounter(ABC): """ Abstract base class for provider-specific token counting. IMPROVEMENT: Each provider can implement accurate token counting using their native tokenizer or API endpoint. """ @abstractmethod def count_tokens(self, text: str) -> int: """Count tokens in text using provider-specific method.""" pass class OpenAITokenCounter(TokenCounter): """Token counter for OpenAI models using tiktoken.""" def __init__(self, model: str): if _TIKTOKEN_AVAILABLE: try: self.encoding = tiktoken.encoding_for_model(model) except KeyError: self.encoding = tiktoken.get_encoding("cl100k_base") else: self.encoding = None def count_tokens(self, text: str) -> int: if self.encoding: return len(self.encoding.encode(text)) return len(text) // 4 class ClaudeTokenCounter(TokenCounter): """ Token counter for Anthropic Claude models. Uses Anthropic's official tokenizer when available, falls back to tiktoken with cl100k_base encoding (close approximation). """ def __init__(self, model: str): self.model = model # Try to import Anthropic's tokenizer try: from anthropic import Anthropic self._client = Anthropic(api_key="dummy") # Tokenizer doesn't need valid key self._has_native = True except (ImportError, Exception): self._has_native = False if _TIKTOKEN_AVAILABLE: self.encoding = tiktoken.get_encoding("cl100k_base") else: self.encoding = None def count_tokens(self, text: str) -> int: if self._has_native: try: return self._client.count_tokens(text) except Exception: pass if self.encoding: return len(self.encoding.encode(text)) return len(text) // 4 class GeminiTokenCounter(TokenCounter): """ Token counter for Google Gemini models. Uses google.generativeai.count_tokens when available. """ def __init__(self, model: str): self.model = model self._has_native = _GEMINI_AVAILABLE def count_tokens(self, text: str) -> int: if self._has_native and genai: try: return genai.count_tokens(text).total_tokens except Exception: pass if _TIKTOKEN_AVAILABLE: enc = tiktoken.get_encoding("cl100k_base") return len(enc.encode(text)) return len(text) // 4 class GenericTokenCounter(TokenCounter): """ Fallback token counter for providers without native tokenizer. Uses tiktoken with cl100k_base encoding as best-effort approximation. """ def __init__(self, model: str): if _TIKTOKEN_AVAILABLE: self.encoding = tiktoken.get_encoding("cl100k_base") else: self.encoding = None def count_tokens(self, text: str) -> int: if self.encoding: return len(self.encoding.encode(text)) return len(text) // 4 def get_token_counter(provider: LLMProvider, model: str) -> TokenCounter: """ Factory function to get appropriate token counter for provider. Parameters ---------- provider : LLMProvider Target provider. model : str Model identifier. Returns ------- TokenCounter Provider-specific token counter instance. """ if provider == LLMProvider.OPENAI: return OpenAITokenCounter(model) elif provider == LLMProvider.CLAUDE: return ClaudeTokenCounter(model) elif provider == LLMProvider.GEMINI: return GeminiTokenCounter(model) else: return GenericTokenCounter(model) # ============================================================================= # ABSTRACT PROVIDER CLIENT # ============================================================================= class BaseLLMClient(ABC): """ Abstract base class for LLM provider clients. Defines the interface that all provider implementations must follow. """ PROVIDER: LLMProvider DEFAULT_MODEL: str def __init__( self, api_key: str, model: Optional[str] = None, retry_config: RetryConfig = RetryConfig(), timeout_config: TimeoutConfig = TimeoutConfig(), rate_limit_config: Optional[RateLimitConfig] = None, circuit_breaker: Optional[CircuitBreaker] = None, base_url: Optional[str] = None, max_concurrent_requests: int = 10, ): self.api_key = api_key self.model = model self.retry_config = retry_config self.timeout_config = timeout_config self.rate_limit_config = rate_limit_config self.circuit_breaker = circuit_breaker or CircuitBreaker() self.base_url = base_url self.max_concurrent_requests = max_concurrent_requests # Initialize rate limiter if configured self._rate_limiter: Optional[TokenBucket] = None if rate_limit_config and rate_limit_config.requests_per_minute: tokens_per_sec = ( rate_limit_config.requests_per_minute * rate_limit_config.tokens_per_request / 60.0 ) self._rate_limiter = TokenBucket( capacity=rate_limit_config.bucket_capacity, refill_rate=tokens_per_sec ) # IMPROVEMENT: Semaphore for concurrency limiting per provider self._concurrency_semaphore = asyncio.Semaphore(max_concurrent_requests) # HTTP client for providers using REST API self._http_client: Optional[httpx.AsyncClient] = None # IMPROVEMENT: Provider-specific token counter self._token_counter: Optional[TokenCounter] = None # Metrics tracking self._request_count = 0 self._error_count = 0 self._total_latency_ms = 0.0 @abstractmethod async def _generate_internal(self, request: LLMRequest) -> LLMResponse: """Internal generation logic specific to the provider.""" pass @abstractmethod async def _generate_stream_internal( self, request: LLMRequest ) -> AsyncGenerator[str, None]: """Streaming generation logic specific to the provider.""" pass def _init_token_counter(self, model: str) -> None: """Initialize provider-specific token counter.""" self._token_counter = get_token_counter(self.PROVIDER, model) def _count_tokens(self, text: str) -> int: """Count tokens using provider-specific method.""" if self._token_counter: return self._token_counter.count_tokens(text) # Fallback if _TIKTOKEN_AVAILABLE: enc = tiktoken.get_encoding("cl100k_base") return len(enc.encode(text)) return len(text) // 4 def _ensure_http_client( self, base_url: str, headers: Dict[str, str], ) -> httpx.AsyncClient: """Return a single shared AsyncClient instance for this provider client.""" if self._http_client is None: self._http_client = httpx.AsyncClient( base_url=base_url, timeout=httpx.Timeout( connect=self.timeout_config.connect, read=self.timeout_config.read, write=self.timeout_config.write, pool=self.timeout_config.pool, ), headers=headers, ) return self._http_client @abstractmethod def _get_pricing_key(self, model: str) -> str: """Get pricing key for cost calculation.""" pass async def _apply_rate_limit(self, tokens: float = 1.0) -> None: """Apply rate limiting if configured.""" if self._rate_limiter: acquired = await self._rate_limiter.acquire( tokens=tokens, timeout=30.0 ) if not acquired: raise LLMProviderError( message="Rate limit timeout: could not acquire tokens", error_type="rate_limit", provider=self.PROVIDER, retryable=True, ) async def _acquire_concurrency_slot(self) -> None: """Acquire a slot from the concurrency semaphore.""" await self._concurrency_semaphore.acquire() def _release_concurrency_slot(self) -> None: """Release a slot from the concurrency semaphore.""" self._concurrency_semaphore.release() async def _with_retry( self, operation: Callable[[], Any], operation_name: str, ) -> Any: """Execute operation with retry logic and exponential backoff.""" last_error: Optional[Exception] = None last_retryable = True # Track if last error was retryable for attempt in range(1, self.retry_config.max_attempts + 1): try: logger.debug( f"[{self.PROVIDER.value}] {operation_name} attempt {attempt}/" f"{self.retry_config.max_attempts}" ) return await operation() except self.retry_config.retryable_exceptions as e: last_error = e last_retryable = True logger.warning( f"[{self.PROVIDER.value}] {operation_name} failed (attempt {attempt}): " f"{type(e).__name__}: {e}" ) if attempt < self.retry_config.max_attempts: delay = min( self.retry_config.base_delay * (self.retry_config.exponential_base ** (attempt - 1)), self.retry_config.max_delay ) if self.retry_config.jitter: jitter = random.uniform(0, 0.1 * self.retry_config.base_delay) delay += jitter logger.info(f"[{self.PROVIDER.value}] Retrying in {delay:.2f}s...") await asyncio.sleep(delay) else: logger.error( f"[{self.PROVIDER.value}] {operation_name} failed after " f"{self.retry_config.max_attempts} attempts" ) except Exception as e: last_error = e last_retryable = False logger.error( f"[{self.PROVIDER.value}] {operation_name} failed with " f"non-retryable error: {type(e).__name__}: {e}" ) raise LLMProviderError( message=str(e), error_type=type(e).__name__.lower(), provider=self.PROVIDER, retryable=False, details={"original_exception": str(e)}, ) # Exhausted retries - record failure for circuit breaker await self.circuit_breaker.record_failure(retryable=last_retryable) raise LLMProviderError( message=f"{operation_name} failed after {self.retry_config.max_attempts} attempts", error_type="max_retries_exceeded", provider=self.PROVIDER, retryable=False, details={"last_error": str(last_error)}, ) async def generate(self, request: LLMRequest) -> LLMResponse: """ Generate text from LLM with full error handling and metrics. IMPROVEMENTS: - Uses provider-specific token counter for accurate cost estimation - Concurrency limiting via semaphore - Circuit breaker only trips on non-retryable errors """ request_id = f"{self.PROVIDER.value}-{int(time.time() * 1000)}-{random.randint(1000, 9999)}" start_time = time.time() logger.info( f"[{request_id}] Starting {self.PROVIDER.value} request: " f"model={request.model}, prompt_length={len(request.prompt)}" ) try: # Check circuit breaker if not await self.circuit_breaker.can_execute(): raise LLMProviderError( message="Circuit breaker is OPEN; provider unavailable", error_type="circuit_open", provider=self.PROVIDER, model=request.model, retryable=True, ) # Acquire concurrency slot await self._acquire_concurrency_slot() try: # Apply rate limiting input_tokens = self._count_tokens(request.prompt) await self._apply_rate_limit(tokens=input_tokens / 1000) # Execute generation with retry response = await self._with_retry( lambda: self._generate_internal(request), operation_name=f"generate({request.model})" ) # Record success await self.circuit_breaker.record_success() # Update metrics latency_ms = (time.time() - start_time) * 1000 self._request_count += 1 self._total_latency_ms += latency_ms # Add request metadata to response response.request_id = request_id response.input_tokens = input_tokens response.output_tokens = self._count_tokens(response.text) response.latency_ms = latency_ms # Calculate cost pricing_key = self._get_pricing_key(request.model) pricing = PROVIDER_PRICING.get( pricing_key, PROVIDER_PRICING.get("custom/*", {"input": 0.0, "output": 0.0}) ) response.cost_usd = ( response.input_tokens * pricing.get("input", 0.0) + response.output_tokens * pricing.get("output", 0.0) ) / 1000 logger.info( f"[{request_id}] {self.PROVIDER.value} success: " f"latency={latency_ms:.0f}ms, tokens_in={response.input_tokens}, " f"tokens_out={response.output_tokens}, cost=${response.cost_usd:.4f}" ) return response finally: self._release_concurrency_slot() except LLMProviderError as e: await self.circuit_breaker.record_failure(retryable=e.retryable) self._error_count += 1 logger.error( f"[{request_id}] {self.PROVIDER.value} failed: " f"{e.error_type} - {e.message}" ) raise except Exception as e: await self.circuit_breaker.record_failure(retryable=False) self._error_count += 1 logger.exception( f"[{request_id}] {self.PROVIDER.value} unexpected error: {e}" ) raise LLMProviderError( message=f"Unexpected error: {str(e)}", error_type="unexpected", provider=self.PROVIDER, model=request.model, retryable=False, ) async def generate_stream( self, request: LLMRequest ) -> AsyncGenerator[Dict[str, Any], None]: """ Generate text with streaming support. Yields partial responses as tokens are generated. Parameters ---------- request : LLMRequest Validated generation request. Yields ------ Dict[str, Any] Streaming chunks with keys: - text: partial/generated text - finished: bool indicating end of stream - metadata: provider-specific streaming metadata """ request_id = f"{self.PROVIDER.value}-stream-{int(time.time() * 1000)}" start_time = time.time() logger.info( f"[{request_id}] Starting {self.PROVIDER.value} streaming request: " f"model={request.model}" ) try: if not await self.circuit_breaker.can_execute(): raise LLMProviderError( message="Circuit breaker is OPEN; provider unavailable", error_type="circuit_open", provider=self.PROVIDER, model=request.model, retryable=True, ) await self._acquire_concurrency_slot() try: input_tokens = self._count_tokens(request.prompt) await self._apply_rate_limit(tokens=input_tokens / 1000) accumulated_text = "" async for chunk in self._generate_stream_internal(request): accumulated_text += chunk yield { "text": chunk, "accumulated": accumulated_text, "finished": False, "metadata": {}, } # Final chunk with completion metadata latency_ms = (time.time() - start_time) * 1000 output_tokens = self._count_tokens(accumulated_text) pricing_key = self._get_pricing_key(request.model) pricing = PROVIDER_PRICING.get( pricing_key, PROVIDER_PRICING.get("custom/*", {"input": 0.0, "output": 0.0}) ) cost_usd = ( input_tokens * pricing.get("input", 0.0) + output_tokens * pricing.get("output", 0.0) ) / 1000 yield { "text": "", "accumulated": accumulated_text, "finished": True, "metadata": { "input_tokens": input_tokens, "output_tokens": output_tokens, "latency_ms": latency_ms, "cost_usd": cost_usd, "request_id": request_id, }, } await self.circuit_breaker.record_success() self._request_count += 1 self._total_latency_ms += latency_ms logger.info( f"[{request_id}] {self.PROVIDER.value} streaming complete: " f"latency={latency_ms:.0f}ms, tokens_out={output_tokens}" ) finally: self._release_concurrency_slot() except LLMProviderError as e: await self.circuit_breaker.record_failure(retryable=e.retryable) self._error_count += 1 logger.error( f"[{request_id}] {self.PROVIDER.value} streaming failed: " f"{e.error_type} - {e.message}" ) raise except Exception as e: await self.circuit_breaker.record_failure(retryable=False) self._error_count += 1 logger.exception( f"[{request_id}] {self.PROVIDER.value} streaming unexpected error: {e}" ) raise LLMProviderError( message=f"Unexpected error: {str(e)}", error_type="unexpected", provider=self.PROVIDER, model=request.model, retryable=False, ) def get_metrics(self) -> Dict[str, Any]: """Return client metrics for monitoring.""" avg_latency = ( self._total_latency_ms / self._request_count if self._request_count > 0 else 0.0 ) error_rate = ( self._error_count / self._request_count if self._request_count > 0 else 0.0 ) return { "provider": self.PROVIDER.value, "request_count": self._request_count, "error_count": self._error_count, "error_rate": error_rate, "avg_latency_ms": avg_latency, "circuit_breaker": self.circuit_breaker.get_stats(), "rate_limiter": ( self._rate_limiter.get_stats() if self._rate_limiter else None ), "concurrency": { "max_concurrent": self.max_concurrent_requests, "available_slots": self._concurrency_semaphore._value, }, } async def health_check(self) -> Dict[str, Any]: """ Perform a lightweight health check for the provider. IMPROVEMENT: Uses more generous timeout (20s) and minimal prompt to reduce false negatives during high load. """ try: test_request = LLMRequest( prompt="ping", provider=self.PROVIDER, model=self.DEFAULT_MODEL, max_tokens=5, # Minimal tokens for faster response temperature=0.0, ) start = time.time() # IMPROVEMENT: More generous timeout for health checks response = await asyncio.wait_for( self._generate_internal(test_request), timeout=20.0 # Increased from 10s to 20s ) latency_ms = (time.time() - start) * 1000 return { "status": "healthy", "latency_ms": latency_ms, "provider": self.PROVIDER.value, "model": self.DEFAULT_MODEL, "timestamp": datetime.utcnow().isoformat(), } except asyncio.TimeoutError: return { "status": "timeout", "provider": self.PROVIDER.value, "error": "Health check timed out (20s limit)", } except Exception as e: return { "status": "unhealthy", "provider": self.PROVIDER.value, "error": str(e), "error_type": type(e).__name__, } async def close(self) -> None: """Clean up resources (HTTP connections, etc.).""" if self._http_client: await self._http_client.aclose() self._http_client = None # ============================================================================= # PROVIDER IMPLEMENTATIONS # ============================================================================= class GeminiClient(BaseLLMClient): PROVIDER = LLMProvider.GEMINI DEFAULT_MODEL = "gemini-pro" def __init__(self, api_key: str, **kwargs): if not _GEMINI_AVAILABLE: raise ImportError( "google-genai not installed. Install with: pip install google-genai" ) model = kwargs.pop('model', None) super().__init__(api_key, model=model, **kwargs) self._client = genai.Client(api_key=api_key) self._init_token_counter(kwargs.get("model", self.DEFAULT_MODEL)) def _get_pricing_key(self, model: str) -> str: return f"google/{model}" async def _generate_internal(self, request: LLMRequest) -> LLMResponse: try: response = await self._client.aio.models.generate_content( model=request.model, contents=request.prompt, config=types.GenerateContentConfig( max_output_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, stop_sequences=request.stop_sequences or [], ) ) except Exception as e: raise LLMProviderError( message=str(e), error_type=type(e).__name__.lower(), provider=self.PROVIDER, model=request.model, retryable=True, ) if not response.text: raise LLMProviderError( message="Gemini returned empty response", error_type="empty_response", provider=self.PROVIDER, model=request.model, retryable=True, ) # Extract metadata from the response if available (provider-specific) metadata = {} if hasattr(response, 'usage_metadata'): metadata["usage"] = { "prompt_tokens": response.usage_metadata.prompt_token_count, "candidates_tokens": response.usage_metadata.candidates_token_count, "total_tokens": response.usage_metadata.total_token_count, } return LLMResponse( text=response.text, provider=self.PROVIDER, model=request.model, input_tokens=0, output_tokens=0, latency_ms=0, cost_usd=0, metadata=metadata, ) async def _generate_stream_internal( self, request: LLMRequest ) -> AsyncGenerator[str, None]: """Streaming generation for Gemini using new API.""" async for chunk in await self._client.aio.models.generate_content_stream( model=request.model, contents=request.prompt, config=types.GenerateContentConfig( max_output_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, ) ): if chunk.text: yield chunk.text def _count_tokens(self, text: str) -> int: if self._token_counter: return self._token_counter.count_tokens(text) if _TIKTOKEN_AVAILABLE: enc = tiktoken.get_encoding("cl100k_base") return len(enc.encode(text)) return len(text) // 4 class ClaudeClient(BaseLLMClient): """Client for Anthropic Claude API.""" PROVIDER = LLMProvider.CLAUDE DEFAULT_MODEL = "claude-3-sonnet-20240229" def __init__(self, api_key: str, **kwargs): model = kwargs.pop('model', None) super().__init__(api_key, model=model, **kwargs) self.base_url = self.base_url or "https://api.anthropic.com" self._init_token_counter(kwargs.get("model", self.DEFAULT_MODEL)) def _get_pricing_key(self, model: str) -> str: return f"anthropic/{model}" async def _generate_internal(self, request: LLMRequest) -> LLMResponse: self._http_client = self._ensure_http_client( self.base_url, { "x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json", }, ) payload = { "model": request.model, "messages": [{"role": "user", "content": request.prompt}], "max_tokens": request.max_tokens, "temperature": request.temperature, "top_p": request.top_p, "top_k": request.top_k, } if request.stop_sequences: payload["stop_sequences"] = request.stop_sequences response = await self._http_client.post("/v1/messages", json=payload) response.raise_for_status() data = response.json() if not data.get("content"): raise LLMProviderError( message="Claude returned empty response", error_type="empty_response", provider=self.PROVIDER, model=request.model, retryable=True, ) text = data["content"][0]["text"] return LLMResponse( text=text, provider=self.PROVIDER, model=request.model, input_tokens=0, output_tokens=0, latency_ms=0, cost_usd=0, metadata={ "stop_reason": data.get("stop_reason"), "model": data.get("model"), "usage": data.get("usage", {}), }, ) async def _generate_stream_internal( self, request: LLMRequest ) -> AsyncGenerator[str, None]: """Streaming generation for Claude.""" self._http_client = self._ensure_http_client( self.base_url, { "x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json", }, ) payload = { "model": request.model, "messages": [{"role": "user", "content": request.prompt}], "max_tokens": request.max_tokens, "temperature": request.temperature, "top_p": request.top_p, "stream": True, } async with self._http_client.stream("POST", "/v1/messages", json=payload) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): data = line[6:].strip() if data == "[DONE]": break try: chunk = json.loads(data) if chunk.get("type") == "content_block_delta": text = chunk.get("delta", {}).get("text", "") if text: yield text except json.JSONDecodeError: continue def _count_tokens(self, text: str) -> int: """Use Claude's native token counting when available.""" if self._token_counter: return self._token_counter.count_tokens(text) # Fallback to tiktoken if _TIKTOKEN_AVAILABLE: enc = tiktoken.get_encoding("cl100k_base") return len(enc.encode(text)) return len(text) // 4 class OpenAIClient(BaseLLMClient): """Client for OpenAI API.""" PROVIDER = LLMProvider.OPENAI DEFAULT_MODEL = "gpt-4" def __init__(self, api_key: str, **kwargs): model = kwargs.pop('model', None) super().__init__(api_key, model=model, **kwargs) self.base_url = self.base_url or "https://api.openai.com/v1" self._init_token_counter(kwargs.get("model", self.DEFAULT_MODEL)) def _get_pricing_key(self, model: str) -> str: return f"openai/{model}" async def _generate_internal(self, request: LLMRequest) -> LLMResponse: self._http_client = self._ensure_http_client( self.base_url, { "Authorization": f"Bearer {self.api_key}", "content-type": "application/json", }, ) payload = { "model": request.model, "messages": [{"role": "user", "content": request.prompt}], "max_tokens": request.max_tokens, "temperature": request.temperature, "top_p": request.top_p, } if request.stop_sequences: payload["stop"] = request.stop_sequences response = await self._http_client.post("/chat/completions", json=payload) response.raise_for_status() data = response.json() if not data.get("choices") or not data["choices"][0].get("message"): raise LLMProviderError( message="OpenAI returned empty response", error_type="empty_response", provider=self.PROVIDER, model=request.model, retryable=True, ) text = data["choices"][0]["message"]["content"] return LLMResponse( text=text, provider=self.PROVIDER, model=request.model, input_tokens=0, output_tokens=0, latency_ms=0, cost_usd=0, metadata={ "finish_reason": data["choices"][0].get("finish_reason"), "usage": data.get("usage", {}), "system_fingerprint": data.get("system_fingerprint"), }, ) async def _generate_stream_internal( self, request: LLMRequest ) -> AsyncGenerator[str, None]: """Streaming generation for OpenAI.""" self._http_client = self._ensure_http_client( self.base_url, { "Authorization": f"Bearer {self.api_key}", "content-type": "application/json", }, ) payload = { "model": request.model, "messages": [{"role": "user", "content": request.prompt}], "max_tokens": request.max_tokens, "temperature": request.temperature, "top_p": request.top_p, "stream": True, } async with self._http_client.stream("POST", "/chat/completions", json=payload) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): data = line[6:].strip() if data == "[DONE]": break try: chunk = json.loads(data) delta = chunk.get("choices", [{}])[0].get("delta", {}) content = delta.get("content", "") if content: yield content except json.JSONDecodeError: continue def _count_tokens(self, text: str) -> int: """Use OpenAI's tiktoken for accurate counting.""" if self._token_counter: return self._token_counter.count_tokens(text) # Fallback if _TIKTOKEN_AVAILABLE: try: enc = tiktoken.encoding_for_model("gpt-4") return len(enc.encode(text)) except KeyError: enc = tiktoken.get_encoding("cl100k_base") return len(enc.encode(text)) return len(text) // 4 class GenericAPIClient(BaseLLMClient): """ Generic client for providers with OpenAI-compatible API format. Supports DeepSeek, Qwen, Kimi, OpenRouter, and custom endpoints. """ PROVIDER = LLMProvider.CUSTOM DEFAULT_MODEL = "" def __init__( self, provider: LLMProvider, api_key: str, base_url: str, default_model: str, **kwargs, ): self.PROVIDER = provider self.DEFAULT_MODEL = default_model super().__init__(api_key, base_url=base_url, **kwargs) self._init_token_counter(default_model) def _get_pricing_key(self, model: str) -> str: prefix = { LLMProvider.DEEPSEEK: "deepseek", LLMProvider.QWEN: "qwen", LLMProvider.KIMI: "moonshot", LLMProvider.OPENROUTER: "openrouter", }.get(self.PROVIDER, "custom") return f"{prefix}/{model}" async def _generate_internal(self, request: LLMRequest) -> LLMResponse: self._http_client = self._ensure_http_client( self.base_url, { "Authorization": f"Bearer {self.api_key}", "content-type": "application/json", }, ) payload = { "model": request.model, "messages": [{"role": "user", "content": request.prompt}], "max_tokens": request.max_tokens, "temperature": request.temperature, "top_p": request.top_p, } if request.stop_sequences: payload["stop"] = request.stop_sequences response = await self._http_client.post("/chat/completions", json=payload) response.raise_for_status() data = response.json() if not data.get("choices") or not data["choices"][0].get("message"): raise LLMProviderError( message=f"{self.PROVIDER.value} returned empty response", error_type="empty_response", provider=self.PROVIDER, model=request.model, retryable=True, ) text = data["choices"][0]["message"]["content"] return LLMResponse( text=text, provider=self.PROVIDER, model=request.model, input_tokens=0, output_tokens=0, latency_ms=0, cost_usd=0, metadata={ "finish_reason": data["choices"][0].get("finish_reason"), "usage": data.get("usage", {}), }, ) async def _generate_stream_internal( self, request: LLMRequest ) -> AsyncGenerator[str, None]: """Streaming generation for OpenAI-compatible APIs.""" self._http_client = self._ensure_http_client( self.base_url, { "Authorization": f"Bearer {self.api_key}", "content-type": "application/json", }, ) payload = { "model": request.model, "messages": [{"role": "user", "content": request.prompt}], "max_tokens": request.max_tokens, "temperature": request.temperature, "top_p": request.top_p, "stream": True, } async with self._http_client.stream("POST", "/chat/completions", json=payload) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): data = line[6:].strip() if data == "[DONE]": break try: chunk = json.loads(data) delta = chunk.get("choices", [{}])[0].get("delta", {}) content = delta.get("content", "") if content: yield content except json.JSONDecodeError: continue # ============================================================================= # FACTORY & ORCHESTRATOR # ============================================================================= class LLMClientFactory: """Factory for creating configured LLM clients.""" _clients: Dict[str, BaseLLMClient] = {} @classmethod def create( cls, provider: LLMProvider, model: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, retry_config: Optional[RetryConfig] = None, timeout_config: Optional[TimeoutConfig] = None, rate_limit_config: Optional[RateLimitConfig] = None, circuit_breaker: Optional[CircuitBreaker] = None, max_concurrent_requests: int = 10, ) -> BaseLLMClient: """Create a configured client for the specified provider.""" if api_key is None: env_key_map = { LLMProvider.GEMINI: "GEMINI_API_KEY", LLMProvider.CLAUDE: "ANTHROPIC_API_KEY", LLMProvider.OPENAI: "OPENAI_API_KEY", LLMProvider.DEEPSEEK: "DEEPSEEK_API_KEY", LLMProvider.QWEN: "QWEN_API_KEY", LLMProvider.KIMI: "KIMI_API_KEY", LLMProvider.OPENROUTER: "OPENROUTER_API_KEY", } env_var = env_key_map.get(provider) if env_var: api_key = os.getenv(env_var) if not api_key: raise ValueError( f"API key required for {provider.value}. " f"Set {env_key_map.get(provider, 'API_KEY')} env var or pass api_key parameter." ) default_model_map = { LLMProvider.GEMINI: "gemini-pro", LLMProvider.CLAUDE: "claude-3-sonnet-20240229", LLMProvider.OPENAI: "gpt-4", LLMProvider.DEEPSEEK: "deepseek-chat", LLMProvider.QWEN: "qwen-max", LLMProvider.KIMI: "kimi", LLMProvider.OPENROUTER: "openai/gpt-4", } model = model or default_model_map.get(provider, "") config_kwargs = { "retry_config": retry_config or RetryConfig(), "timeout_config": timeout_config or TimeoutConfig(), "rate_limit_config": rate_limit_config, "circuit_breaker": circuit_breaker, "max_concurrent_requests": max_concurrent_requests, } if provider == LLMProvider.GEMINI: return GeminiClient(api_key=api_key, model=model, **config_kwargs) elif provider == LLMProvider.CLAUDE: return ClaudeClient(api_key=api_key, model=model, **config_kwargs) elif provider == LLMProvider.OPENAI: return OpenAIClient(api_key=api_key, model=model, **config_kwargs) elif provider in (LLMProvider.DEEPSEEK, LLMProvider.QWEN, LLMProvider.KIMI, LLMProvider.OPENROUTER): base_urls = { LLMProvider.DEEPSEEK: "https://api.deepseek.com/v1", LLMProvider.QWEN: "https://dashscope.aliyuncs.com/compatible-mode/v1", LLMProvider.KIMI: "https://api.moonshot.cn/v1", LLMProvider.OPENROUTER: "https://openrouter.ai/api/v1", } return GenericAPIClient( provider=provider, api_key=api_key, base_url=base_url or base_urls[provider], default_model=model, **config_kwargs, ) elif provider == LLMProvider.CUSTOM: if not base_url: raise ValueError("base_url required for CUSTOM provider") return GenericAPIClient( provider=provider, api_key=api_key, base_url=base_url, default_model=model, **config_kwargs, ) else: raise ValueError(f"Unsupported provider: {provider}") @classmethod def get_or_create(cls, provider: LLMProvider, **kwargs) -> BaseLLMClient: """Get existing client or create new one (singleton per provider config).""" key = f"{provider.value}:{kwargs.get('model', 'default')}" if key not in cls._clients: cls._clients[key] = cls.create(provider, **kwargs) return cls._clients[key] @classmethod async def close_all(cls) -> None: """Close all managed clients.""" for client in cls._clients.values(): await client.close() cls._clients.clear() class LLMOrchestrator: """High-level orchestrator for multi-provider LLM inference.""" def __init__( self, default_provider: LLMProvider = LLMProvider.OPENAI, fallback_providers: Optional[List[LLMProvider]] = None, load_balance: bool = False, max_concurrent_requests: int = 10, default_model: Optional[str] = None, ): self.default_provider = default_provider self.fallback_providers = fallback_providers or [] self.load_balance = load_balance self.max_concurrent_requests = max_concurrent_requests self._provider_index = 0 self.default_model = default_model or "" async def generate( self, prompt: str, provider: Optional[LLMProvider] = None, model: Optional[str] = None, **request_kwargs, ) -> LLMResponse: """Generate text with automatic provider selection and fallback.""" if provider is None: if self.load_balance and self.fallback_providers: providers = [self.default_provider] + self.fallback_providers provider = providers[self._provider_index % len(providers)] self._provider_index += 1 else: provider = self.default_provider # Use orchestrator default model when none provided effective_model = model or self.default_model or "" request = LLMRequest( prompt=prompt, provider=provider, model=effective_model, **request_kwargs, ) try: client = LLMClientFactory.get_or_create( provider, model=request.model, max_concurrent_requests=self.max_concurrent_requests, ) return await client.generate(request) except LLMProviderError as e: logger.warning( f"Primary provider {provider.value} failed: {e.error_type} - {e.message}" ) if not e.retryable: raise for fallback in self.fallback_providers: if fallback == provider: continue try: logger.info(f"Trying fallback provider: {fallback.value}") client = LLMClientFactory.get_or_create( fallback, model=request.model, max_concurrent_requests=self.max_concurrent_requests, ) response = await client.generate(request) response.metadata["fallback_from"] = provider.value return response except LLMProviderError as e: logger.warning(f"Fallback provider {fallback.value} failed: {e.error_type}") continue raise LLMProviderError( message="All providers failed to generate response", error_type="all_providers_failed", provider=provider, model=request.model, retryable=False, ) async def generate_stream( self, prompt: str, provider: Optional[LLMProvider] = None, model: Optional[str] = None, **request_kwargs, ) -> AsyncGenerator[Dict[str, Any], None]: """Generate text with streaming support and automatic provider selection.""" if provider is None: provider = self.default_provider effective_model = model or self.default_model or "" request = LLMRequest( prompt=prompt, provider=provider, model=effective_model, **request_kwargs, ) client = LLMClientFactory.get_or_create( provider, model=request.model, max_concurrent_requests=self.max_concurrent_requests, ) async for chunk in client.generate_stream(request): yield chunk async def health_check_all(self) -> Dict[str, Dict[str, Any]]: """Perform health checks on all configured providers.""" results = {} for provider in [self.default_provider] + self.fallback_providers: try: client = LLMClientFactory.get_or_create(provider) results[provider.value] = await client.health_check() except Exception as e: results[provider.value] = { "status": "error", "error": str(e), } return results def get_metrics(self) -> Dict[str, Dict[str, Any]]: """Get metrics from all configured clients.""" metrics = {} for key, client in LLMClientFactory._clients.items(): metrics[key] = client.get_metrics() return metrics async def close(self) -> None: """Clean up all resources.""" await LLMClientFactory.close_all() # ============================================================================= # EXCEPTIONS # ============================================================================= class LLMProviderError(Exception): """Base exception for LLM provider errors.""" def __init__( self, message: str, error_type: str, provider: LLMProvider, model: Optional[str] = None, retryable: bool = False, details: Optional[Dict[str, Any]] = None, ): super().__init__(message) self.message = message self.error_type = error_type self.provider = provider self.model = model self.retryable = retryable self.details = details or {} self.timestamp = datetime.utcnow() def __str__(self) -> str: return ( f"[{self.provider.value}] {self.error_type}: {self.message}" f"{' (retryable)' if self.retryable else ''}" ) def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for logging/serialization.""" return { "message": self.message, "error_type": self.error_type, "provider": self.provider.value, "model": self.model, "retryable": self.retryable, "details": self.details, "timestamp": self.timestamp.isoformat(), } # ============================================================================= # UTILITIES # ============================================================================= def validate_prompt(prompt: str, max_length: int = 100000) -> str: """Validate and sanitize prompt input.""" if not prompt or not isinstance(prompt, str): raise ValueError("prompt must be a non-empty string") prompt = prompt.strip() if not prompt: raise ValueError("prompt cannot be empty after stripping") if len(prompt) > max_length: raise ValueError( f"prompt exceeds maximum length of {max_length} characters " f"(got {len(prompt)})" ) return prompt def estimate_cost( input_tokens: int, output_tokens: int, provider: LLMProvider, model: str, ) -> float: """Estimate generation cost without making a request.""" pricing_key = f"{provider.value}/{model}" pricing = PROVIDER_PRICING.get( pricing_key, PROVIDER_PRICING.get(f"{provider.value}/*", ProviderPricing(0, 0)) ) return ( input_tokens * pricing.input_price + output_tokens * pricing.output_price ) / 1000