Spaces:
Running
Running
| """Application services for the Claude-compatible API.""" | |
| from __future__ import annotations | |
| import traceback | |
| import uuid | |
| from collections.abc import AsyncIterator, Callable | |
| from typing import Any | |
| from fastapi import HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from loguru import logger | |
| from config.settings import Settings | |
| from core.anthropic import get_token_count, get_user_facing_error_message | |
| from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event | |
| from core.session_tracker import SessionTracker | |
| from providers.base import BaseProvider | |
| from providers.exceptions import ( | |
| InvalidRequestError, | |
| OverloadedError, | |
| ProviderError, | |
| RateLimitError, | |
| ) | |
| from .model_router import ModelRouter | |
| from .models.anthropic import MessagesRequest, TokenCountRequest | |
| from .models.responses import TokenCountResponse | |
| from .optimization_handlers import try_optimizations | |
| from .web_tools.egress import WebFetchEgressPolicy | |
| from .web_tools.request import ( | |
| is_web_server_tool_request, | |
| openai_chat_upstream_server_tool_error, | |
| ) | |
# Token counter signature: called as (messages, system, tools) and returns
# the token count as an int.
TokenCounter = Callable[
    [list[Any], str | list[Any] | None, list[Any] | None],
    int,
]
# Looks up a provider instance by provider id.
ProviderGetter = Callable[[str], BaseProvider]

# Provider ids served through ``/chat/completions`` with Anthropic-to-OpenAI
# request conversion (rather than a native Messages endpoint).
_OPENAI_CHAT_UPSTREAM_IDS = frozenset(("cerebras", "groq", "nvidia_nim"))
def anthropic_sse_streaming_response(
    body: AsyncIterator[str],
) -> StreamingResponse:
    """Wrap an async iterator of SSE-formatted strings in a streaming response.

    The response is served as ``text/event-stream`` and carries the shared
    ``ANTHROPIC_SSE_RESPONSE_HEADERS``.
    """
    response_options = {
        "media_type": "text/event-stream",
        "headers": ANTHROPIC_SSE_RESPONSE_HEADERS,
    }
    return StreamingResponse(body, **response_options)
| def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int: | |
| """HTTP status for uncaught non-provider failures (stable client contract).""" | |
| return 500 | |
def _log_unexpected_service_exception(
    settings: Settings,
    exc: BaseException,
    *,
    context: str,
    request_id: str | None = None,
) -> None:
    """Log a service-layer failure, echoing exception text only when opted in.

    Args:
        settings: Its ``log_api_error_tracebacks`` flag selects the verbosity.
            When true, the log line includes ``str(exc)`` followed by the full
            traceback of ``exc``. Otherwise only the exception type name is
            logged, so exception text never reaches the logs by default.
        exc: The exception being reported.
        context: Tag placed at the start of the log line
            (e.g. ``"CREATE_MESSAGE_ERROR"``).
        request_id: Optional request id appended as ``request_id=<id>``.
    """
    if settings.log_api_error_tracebacks:
        if request_id is not None:
            logger.error("{} request_id={}: {}", context, request_id, exc)
        else:
            logger.error("{}: {}", context, exc)
        # Format the traceback of ``exc`` itself. ``traceback.format_exc()``
        # reflects whatever exception is *currently being handled*. That is
        # wrong, or just "NoneType: None", when this helper is called outside
        # the matching ``except`` block.
        logger.error("{}", "".join(traceback.format_exception(exc)))
        return
    if request_id is not None:
        logger.error(
            "{} request_id={} exc_type={}",
            context,
            request_id,
            type(exc).__name__,
        )
    else:
        logger.error("{} exc_type={}", context, type(exc).__name__)
| def _require_non_empty_messages(messages: list[Any]) -> None: | |
| if not messages: | |
| raise InvalidRequestError("messages cannot be empty") | |
class ClaudeProxyService:
    """Coordinate request optimization, model routing, token count, and providers.

    Public entry points validate the request and resolve the requested model
    through the model router. They then either answer locally (optimizations,
    web server tools, token counting) or delegate to a provider obtained from
    ``provider_getter``.
    """

    def __init__(
        self,
        settings: Settings,
        provider_getter: ProviderGetter,
        model_router: ModelRouter | None = None,
        token_counter: TokenCounter = get_token_count,
    ):
        """Store collaborators.

        Args:
            settings: Application settings (feature flags, logging policy).
            provider_getter: Called with a provider id to obtain the provider.
            model_router: Router used to resolve model names. Defaults to
                ``ModelRouter(settings)``.
            token_counter: Called as ``(messages, system, tools)`` to compute
                ``input_tokens``.
        """
        self._settings = settings
        self._provider_getter = provider_getter
        self._model_router = model_router or ModelRouter(settings)
        self._token_counter = token_counter
        self._session_tracker = SessionTracker.get_instance()

    def _get_session_id(self, request_data: MessagesRequest) -> str:
        """Return the session id used to track ``request_data``.

        Uses ``request_data.custom_id`` when it is present and non-empty.
        Otherwise it generates a random ``session_<12 hex chars>`` id, so each
        such request is tracked as its own session.
        """
        # ``hasattr`` alone is not enough. An attribute that exists but holds
        # ``None`` would become the literal id "None", lumping every such
        # request into one shared session.
        custom_id = getattr(request_data, "custom_id", None)
        if custom_id is not None:
            session_id = str(custom_id)
            if session_id:
                return session_id
        return f"session_{uuid.uuid4().hex[:12]}"

    def create_message(self, request_data: MessagesRequest) -> object:
        """Create a message response or streaming response with optional failover.

        ``InvalidRequestError`` is raised for empty ``messages`` or when the
        router yields no candidates. ``ProviderError`` exceptions propagate
        unchanged. Any other exception is logged according to the settings'
        logging policy and re-raised as an ``HTTPException`` with a
        user-facing detail message.
        """
        try:
            _require_non_empty_messages(request_data.messages)
            candidates = self._model_router.resolve_candidates(request_data.model)
            if not candidates:
                raise InvalidRequestError(
                    f"No configured models available for '{request_data.model}'"
                )
            # For 'auto' requests with multiple candidates, stream through a
            # failover loop.
            # NOTE(review): this path skips the server-tool compatibility check,
            # web-tool handling, optimizations and session tracking that
            # _create_single_message performs; confirm that is intended.
            if len(candidates) > 1:
                return anthropic_sse_streaming_response(
                    self._stream_with_fallbacks(candidates, request_data)
                )
            # Standard path for single-model requests.
            return self._create_single_message(candidates[0], request_data)
        except ProviderError:
            raise
        except Exception as e:
            _log_unexpected_service_exception(
                self._settings, e, context="CREATE_MESSAGE_ERROR"
            )
            raise HTTPException(
                status_code=_http_status_for_unexpected_service_exception(e),
                detail=get_user_facing_error_message(e),
            ) from e

    # NOTE(review): ``ResolvedModel`` is not imported in this module. The
    # annotations below only work because of postponed evaluation
    # (``from __future__ import annotations``); consider a TYPE_CHECKING import.
    def _create_single_message(
        self, resolved: ResolvedModel, request_data: MessagesRequest
    ) -> object:
        """Create a single message response from a resolved model.

        Handling order:
        1. Reject server tools an OpenAI-chat upstream cannot serve.
        2. Answer web server tool requests locally (when enabled).
        3. Apply request optimizations.
        4. Stream from the provider.

        Raises:
            InvalidRequestError: When an OpenAI-chat upstream reports a
                server-tool incompatibility for this request.
        """
        routed_request = request_data.model_copy(deep=True)
        routed_request.model = resolved.provider_model

        if resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
            tool_err = openai_chat_upstream_server_tool_error(
                routed_request,
                web_tools_enabled=self._settings.enable_web_server_tools,
            )
            if tool_err is not None:
                raise InvalidRequestError(tool_err)

        if self._settings.enable_web_server_tools and is_web_server_tool_request(
            routed_request
        ):
            # Deferred import kept local (presumably to avoid an import cycle).
            # It must live in this function: importing it inside create_message
            # left the name undefined here and raised NameError.
            from .web_tools.streaming import stream_web_server_tool_response

            input_tokens = self._token_counter(
                routed_request.messages, routed_request.system, routed_request.tools
            )
            logger.info("Optimization: Handling Anthropic web server tool")
            egress = WebFetchEgressPolicy(
                allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
                allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
            )
            return anthropic_sse_streaming_response(
                stream_web_server_tool_response(
                    routed_request,
                    input_tokens=input_tokens,
                    web_fetch_egress=egress,
                    verbose_client_errors=self._settings.log_api_error_tracebacks,
                ),
            )

        optimized = try_optimizations(routed_request, self._settings)
        if optimized is not None:
            return optimized

        provider = self._provider_getter(resolved.provider_id)
        provider.preflight_stream(
            routed_request,
            thinking_enabled=resolved.thinking_enabled,
        )
        session_id = self._get_session_id(request_data)
        self._session_tracker.track_request_sync(session_id, resolved.provider_id)
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        logger.info(
            "API_REQUEST: request_id={} model={} messages={}",
            request_id,
            routed_request.model,
            len(routed_request.messages),
        )
        input_tokens = self._token_counter(
            routed_request.messages, routed_request.system, routed_request.tools
        )
        return anthropic_sse_streaming_response(
            provider.stream_response(
                routed_request,
                input_tokens=input_tokens,
                request_id=request_id,
                thinking_enabled=resolved.thinking_enabled,
            ),
        )

    async def _stream_with_fallbacks(
        self, candidates: list[ResolvedModel], request_data: MessagesRequest
    ) -> AsyncIterator[str]:
        """Stream from the first candidate that succeeds, trying them in order.

        A candidate is skipped only if it fails *before* any of its events
        reaches the client: provider lookup, preflight, token counting, or
        the stream raising before its first event.

        Once one event has been yielded, the response is committed to that
        candidate. A later failure is re-raised instead of splicing another
        provider's output into the same stream.

        If every candidate fails, an Anthropic ``error`` SSE event is yielded
        and the last exception is raised.
        """
        last_exc: Exception | None = None
        for attempt, resolved in enumerate(candidates, start=1):
            request_id = f"req_{uuid.uuid4().hex[:12]}"
            committed = False
            try:
                provider = self._provider_getter(resolved.provider_id)
                routed_request = request_data.model_copy(deep=True)
                routed_request.model = resolved.provider_model
                provider.preflight_stream(
                    routed_request,
                    thinking_enabled=resolved.thinking_enabled,
                )
                logger.info(
                    "API_REQUEST (auto fallback {}/{}): request_id={} provider={} model={}",
                    attempt,
                    len(candidates),
                    request_id,
                    resolved.provider_id,
                    resolved.provider_model,
                )
                input_tokens = self._token_counter(
                    routed_request.messages, routed_request.system, routed_request.tools
                )
                async for event in provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                    thinking_enabled=resolved.thinking_enabled,
                ):
                    # The first event handed to the client commits us to this
                    # provider; failures after this point must not fall back.
                    committed = True
                    yield event
                return  # Success, exit the fallback loop.
            except (RateLimitError, OverloadedError) as e:
                if committed:
                    raise
                logger.warning(
                    "Provider '{}' is rate limited or overloaded ({}). Trying next candidate...",
                    resolved.provider_id,
                    getattr(e, "status_code", None),
                )
                last_exc = e
            except Exception as e:
                if committed:
                    raise
                # Follow the module's logging policy: exception text and
                # traceback only when log_api_error_tracebacks is enabled.
                _log_unexpected_service_exception(
                    self._settings,
                    e,
                    context=f"AUTO_FALLBACK_PROVIDER_ERROR provider={resolved.provider_id}",
                    request_id=request_id,
                )
                logger.warning(
                    "Provider '{}' failed with unexpected error. Trying next candidate...",
                    resolved.provider_id,
                )
                last_exc = e

        # Every candidate failed before committing. Report it with the same
        # user-facing message the non-streaming path uses, not the raw text.
        err_msg = (
            get_user_facing_error_message(last_exc)
            if last_exc is not None
            else "No candidates succeeded"
        )
        yield format_sse_event(
            "error",
            {
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": f"All fallback candidates failed: {err_msg}",
                },
            },
        )
        if last_exc is not None:
            raise last_exc
        raise InvalidRequestError("No candidates succeeded")

    def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
        """Count tokens for a request after applying configured model routing.

        ``ProviderError`` exceptions propagate unchanged. Any other exception
        is logged with a generated request id and re-raised as an
        ``HTTPException``.
        """
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        with logger.contextualize(request_id=request_id):
            try:
                _require_non_empty_messages(request_data.messages)
                routed = self._model_router.resolve_token_count_request(request_data)
                tokens = self._token_counter(
                    routed.request.messages, routed.request.system, routed.request.tools
                )
                logger.info(
                    "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
                    request_id,
                    routed.request.model,
                    len(routed.request.messages),
                    tokens,
                )
                return TokenCountResponse(input_tokens=tokens)
            except ProviderError:
                raise
            except Exception as e:
                _log_unexpected_service_exception(
                    self._settings,
                    e,
                    context="COUNT_TOKENS_ERROR",
                    request_id=request_id,
                )
                raise HTTPException(
                    status_code=_http_status_for_unexpected_service_exception(e),
                    detail=get_user_facing_error_message(e),
                ) from e