| """Application services for the Claude-compatible API.""" | |
| from __future__ import annotations | |
| import traceback | |
| import uuid | |
| from collections.abc import AsyncIterator, Callable | |
| from typing import Any | |
| from fastapi import HTTPException, Request | |
| from fastapi.responses import StreamingResponse | |
| from loguru import logger | |
| from config.settings import Settings, get_settings | |
| from core.anthropic import get_token_count, get_user_facing_error_message | |
| from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event | |
| from core.session_tracker import SessionTracker | |
| from providers.base import BaseProvider | |
| from providers.exceptions import ( | |
| InvalidRequestError, | |
| OverloadedError, | |
| ProviderError, | |
| RateLimitError, | |
| ) | |
| from .model_router import ModelRouter, ResolvedModel | |
| from .models.anthropic import MessagesRequest, TokenCountRequest | |
| from .models.responses import TokenCountResponse | |
| from .optimization_handlers import try_optimizations | |
| from .web_tools.egress import WebFetchEgressPolicy | |
| from .web_tools.request import ( | |
| is_web_server_tool_request, | |
| openai_chat_upstream_server_tool_error, | |
| ) | |
| TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int] | |
| ProviderGetter = Callable[[str], BaseProvider] | |
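# ``TokenCounter`` matches the ``get_token_count(messages, system, tools)``
# signature imported above; callers can inject a substitute (for example, a
# counting stub in tests; illustrative, not required by this module).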

# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "groq", "cerebras", "silicon"})


def anthropic_sse_streaming_response(
    body: AsyncIterator[str],
) -> StreamingResponse:
    """Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
    return StreamingResponse(
        body,
        media_type="text/event-stream",
        headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
    )


def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
    """HTTP status for uncaught non-provider failures (stable client contract)."""
    return 500


def _log_unexpected_service_exception(
    settings: Settings,
    exc: BaseException,
    *,
    context: str,
    request_id: str | None = None,
) -> None:
    """Log service-layer failures without echoing exception text unless opted in."""
    if settings.log_api_error_tracebacks:
        if request_id is not None:
            logger.error("{} request_id={}: {}", context, request_id, exc)
        else:
            logger.error("{}: {}", context, exc)
        logger.error(traceback.format_exc())
        return
    if request_id is not None:
        logger.error(
            "{} request_id={} exc_type={}",
            context,
            request_id,
            type(exc).__name__,
        )
    else:
        logger.error("{} exc_type={}", context, type(exc).__name__)


def _require_non_empty_messages(messages: list[Any]) -> None:
    if not messages:
        raise InvalidRequestError("messages cannot be empty")


def _get_client_ip(request: Request) -> str | None:
    """Extract the client IP from gateway headers, or return None for direct connections."""
    # Check proxy/gateway headers in order of specificity.
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        return forwarded.split(",")[0].strip()
    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip
    client_ip = request.headers.get("X-Client-IP")
    if client_ip:
        return client_ip
    via = request.headers.get("Via")
    if via:
        # Only the gateway/proxy address is visible here; request.client can
        # be None under some test clients and transports, so guard it.
        return request.client.host if request.client else None
    return None  # Direct connection


def _get_session_id(request: Request) -> str:
    """Get the session ID from the X-Session-ID header, or fall back to the gateway IP.

    Claude Code sends X-Session-ID when started with --session-id <uuid>.
    """
    session = request.headers.get("X-Session-ID")
    if session:
        return session
    ip = _get_client_ip(request)
    return f"gateway_{ip}" if ip else "direct"


class ClaudeProxyService:
    """Coordinate request optimization, model routing, and providers."""

    def __init__(
        self,
        settings: Settings,
        provider_getter: ProviderGetter,
        model_router: ModelRouter | None = None,
        token_counter: TokenCounter = get_token_count,
    ):
        self._settings = settings
        self._provider_getter = provider_getter
        self._model_router = model_router or ModelRouter(settings)
        self._token_counter = token_counter
        self._session_tracker = SessionTracker.get_instance(
            retention_seconds=settings.session_retention_minutes * 60
        )

    def create_message(self, request: Request, request_data: MessagesRequest) -> object:
        """Create a message response or streaming response with optional failover."""
        try:
            _require_non_empty_messages(request_data.messages)
            candidates = self._model_router.resolve_candidates(request_data.model)
            if not candidates:
                raise InvalidRequestError(
                    f"No configured models available for '{request_data.model}'"
                )
            # Log the routing decision before dispatching.
            logger.info(
                "REQUEST_MODEL_ROUTING: requested={} resolved_provider={} resolved_model={}",
                request_data.model,
                candidates[0].provider_id,
                candidates[0].provider_model,
            )
            # For 'auto' requests with multiple candidates, wrap the stream in a failover loop.
            if len(candidates) > 1:
                return anthropic_sse_streaming_response(
                    self._stream_with_fallbacks(request, candidates, request_data)
                )
            # Standard path for single-model requests.
            return self._create_single_message(request, candidates[0], request_data)
        except ProviderError:
            raise
        except Exception as e:
            _log_unexpected_service_exception(
                self._settings, e, context="CREATE_MESSAGE_ERROR"
            )
            raise HTTPException(
                status_code=_http_status_for_unexpected_service_exception(e),
                detail=get_user_facing_error_message(e),
            ) from e
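
    # A request that reaches this path might look like the following
    # (illustrative endpoint and payload; the actual route prefix is defined
    # by the application, not by this module):
    #
    #   curl -N http://localhost:8000/v1/messages \
    #     -H 'content-type: application/json' \
    #     -d '{"model": "auto", "max_tokens": 256, "stream": true,
    #          "messages": [{"role": "user", "content": "hello"}]}'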

    def _create_single_message(
        self, request: Request, resolved: ResolvedModel, request_data: MessagesRequest
    ) -> object:
        """Create a single message response from a resolved model."""
        routed_request = request_data.model_copy(deep=True)
        routed_request.model = resolved.provider_model
        if resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
            tool_err = openai_chat_upstream_server_tool_error(
                routed_request,
                web_tools_enabled=self._settings.enable_web_server_tools,
            )
            if tool_err is not None:
                raise InvalidRequestError(tool_err)
        if self._settings.enable_web_server_tools and is_web_server_tool_request(
            routed_request
        ):
            from .web_tools.streaming import stream_web_server_tool_response

            input_tokens = self._token_counter(
                routed_request.messages, routed_request.system, routed_request.tools
            )
            logger.info("Optimization: Handling Anthropic web server tool")
            egress = WebFetchEgressPolicy(
                allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
                allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
            )
            return anthropic_sse_streaming_response(
                stream_web_server_tool_response(
                    routed_request,
                    input_tokens=input_tokens,
                    web_fetch_egress=egress,
                    verbose_client_errors=self._settings.log_api_error_tracebacks,
                ),
            )
        optimized = try_optimizations(routed_request, self._settings)
        if optimized is not None:
            return optimized
        provider = self._provider_getter(resolved.provider_id)
        provider.preflight_stream(
            routed_request,
            thinking_enabled=resolved.thinking_enabled,
        )
        session_id = _get_session_id(request)
        self._session_tracker.track_request_sync(session_id, resolved.provider_id)
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        logger.info(
            "API_REQUEST: request_id={} model={} messages={}",
            request_id,
            routed_request.model,
            len(routed_request.messages),
        )
        input_tokens = self._token_counter(
            routed_request.messages, routed_request.system, routed_request.tools
        )
        return anthropic_sse_streaming_response(
            provider.stream_response(
                routed_request,
                input_tokens=input_tokens,
                request_id=request_id,
                thinking_enabled=resolved.thinking_enabled,
            ),
        )

    async def _stream_with_fallbacks(
        self,
        request: Request,
        candidates: list[ResolvedModel],
        request_data: MessagesRequest,
    ) -> AsyncIterator[str]:
        """Iterate through candidates until one succeeds or all fail."""
        # Deferred import, resolved once per call rather than per candidate.
        from providers.rate_limit import GlobalRateLimiter

        last_exc: Exception | None = None
        for i, resolved in enumerate(candidates):
            try:
                # Pre-check: skip candidates that are currently rate limited or unhealthy.
                limiter = GlobalRateLimiter.get_scoped_instance(resolved.provider_id)
                if limiter.is_blocked() and resolved.provider_id != "zen":
                    # Silently skip; no failure penalty for a temporary rate limit.
                    logger.debug(
                        "Skipping blocked provider '{}' (no penalty)",
                        resolved.provider_id,
                    )
                    continue
                # Check model health (recent failures).
                if not limiter.is_healthy(resolved.provider_model_ref):
                    logger.warning(
                        "Provider '{}' has recent failures, skipping to next candidate...",
                        resolved.provider_model_ref,
                    )
                    last_exc = Exception(
                        f"Recent failures for '{resolved.provider_model_ref}'"
                    )
                    continue
                provider = self._provider_getter(resolved.provider_id)
                routed_request = request_data.model_copy(deep=True)
                routed_request.model = resolved.provider_model
                provider.preflight_stream(
                    routed_request,
                    thinking_enabled=resolved.thinking_enabled,
                )
                session_id = _get_session_id(request)
                self._session_tracker.track_request_sync(
                    session_id, resolved.provider_id
                )
                request_id = f"req_{uuid.uuid4().hex[:12]}"
                logger.info(
                    "API_REQUEST (auto fallback {}/{}): request_id={} provider={} model={}",
                    i + 1,
                    len(candidates),
                    request_id,
                    resolved.provider_id,
                    resolved.provider_model,
                )
                input_tokens = self._token_counter(
                    routed_request.messages, routed_request.system, routed_request.tools
                )
                # Attempt to stream from this provider.
                async for event in provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                    thinking_enabled=resolved.thinking_enabled,
                ):
                    yield event
                # CRITICAL: once even one event has been yielded, this provider is
                # committed; falling back to another candidate mid-stream would
                # corrupt the client's event sequence.
                return  # Success, exit the fallback loop.
            except (RateLimitError, OverloadedError) as e:
                logger.warning(
                    "Provider '{}' is rate limited or overloaded ({}). Trying next candidate...",
                    resolved.provider_id,
                    e.status_code,
                )
                limiter.record_failure(resolved.provider_model_ref)
                last_exc = e
                continue
            except TimeoutError as e:
                # A timeout means a slow model; try the next candidate for a faster response.
                logger.warning(
                    "Provider '{}' timed out ({}). Trying next candidate...",
                    resolved.provider_id,
                    type(e).__name__,
                )
                limiter.record_failure(resolved.provider_model_ref)
                last_exc = e
                continue
            except Exception as e:
                # Check whether it is a transient error that should trigger fallback.
                error_str = str(e).lower()
                is_transient = any(
                    kw in error_str
                    for kw in [
                        "timeout",
                        "connection",
                        "refused",
                        "reset",
                        "unavailable",
                        "service",
                    ]
                )
                if is_transient:
                    logger.warning(
                        "Provider '{}' failed with transient error ({}): {}. Trying next candidate...",
                        resolved.provider_id,
                        type(e).__name__,
                        e,
                    )
                    limiter.record_failure(resolved.provider_model_ref)
                    last_exc = e
                    continue
                logger.error(
                    "Provider '{}' failed with unexpected error: {}. Trying next candidate...",
                    resolved.provider_id,
                    e,
                )
                last_exc = e
                continue
        # All candidates failed. The response headers (and possibly events) have
        # already been sent on this stream, so raising here could not surface as
        # an HTTP error; report the failure in-band as an SSE error event instead.
        err_msg = str(last_exc) if last_exc else "No candidates succeeded"
        logger.error("All fallback candidates failed: {}", err_msg)
        yield format_sse_event(
            "error",
            {
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": f"All fallback candidates failed: {err_msg}",
                },
            },
        )
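
    # The terminal error event above renders on the wire roughly as follows
    # (illustrative; the exact serialization is up to ``format_sse_event``):
    #
    #   event: error
    #   data: {"type": "error", "error": {"type": "api_error", "message": "..."}}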

    def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
        """Count tokens for a request after applying configured model routing."""
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        with logger.contextualize(request_id=request_id):
            try:
                _require_non_empty_messages(request_data.messages)
                routed = self._model_router.resolve_token_count_request(request_data)
                tokens = self._token_counter(
                    routed.request.messages, routed.request.system, routed.request.tools
                )
                logger.info(
                    "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
                    request_id,
                    routed.request.model,
                    len(routed.request.messages),
                    tokens,
                )
                return TokenCountResponse(input_tokens=tokens)
            except ProviderError:
                raise
            except Exception as e:
                _log_unexpected_service_exception(
                    self._settings,
                    e,
                    context="COUNT_TOKENS_ERROR",
                    request_id=request_id,
                )
                raise HTTPException(
                    status_code=_http_status_for_unexpected_service_exception(e),
                    detail=get_user_facing_error_message(e),
                ) from e
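
# Example wiring (a sketch; the route paths, app setup, and the ``get_provider``
# registry function are assumptions, not part of this module; ``get_settings``
# is the factory in ``config.settings``):
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#   service = ClaudeProxyService(get_settings(), provider_getter=get_provider)
#
#   @app.post("/v1/messages")
#   async def create_message(request: Request, body: MessagesRequest):
#       return service.create_message(request, body)
#
#   @app.post("/v1/messages/count_tokens")
#   async def count_tokens(body: TokenCountRequest):
#       return service.count_tokens(body)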