# claude-code-proxy / api/services.py
# Author: Yash030
# Last commit: "Fix: use sync version of track_request to avoid SyntaxError" (b5bd2a8)
# File size at that revision: 12.3 kB
"""Application services for the Claude-compatible API."""
from __future__ import annotations
import traceback
import uuid
from collections.abc import AsyncIterator, Callable
from typing import Any
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from config.settings import Settings
from core.anthropic import get_token_count, get_user_facing_error_message
from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event
from core.session_tracker import SessionTracker
from providers.base import BaseProvider
from providers.exceptions import (
InvalidRequestError,
OverloadedError,
ProviderError,
RateLimitError,
)
from .model_router import ModelRouter
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
from .web_tools.egress import WebFetchEgressPolicy
from .web_tools.request import (
is_web_server_tool_request,
openai_chat_upstream_server_tool_error,
)
# Token-counting callable: (messages, system, tools) -> number of input tokens.
TokenCounter = Callable[
    [list[Any], str | list[Any] | None, list[Any] | None],
    int,
]
# Looks up the provider instance that serves a given provider id.
ProviderGetter = Callable[[str], BaseProvider]
# Provider ids reached via ``/chat/completions`` with Anthropic-to-OpenAI conversion
# (as opposed to providers that speak the native Messages API).
_OPENAI_CHAT_UPSTREAM_IDS = frozenset(("nvidia_nim", "groq", "cerebras"))
def anthropic_sse_streaming_response(
    body: AsyncIterator[str],
) -> StreamingResponse:
    """Wrap an async iterator of SSE frames in a :class:`StreamingResponse`.

    The response is served as ``text/event-stream`` and carries the shared
    Anthropic SSE response headers.
    """
    response = StreamingResponse(
        content=body,
        media_type="text/event-stream",
        headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
    )
    return response
def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
"""HTTP status for uncaught non-provider failures (stable client contract)."""
return 500
def _log_unexpected_service_exception(
    settings: Settings,
    exc: BaseException,
    *,
    context: str,
    request_id: str | None = None,
) -> None:
    """Log service-layer failures without echoing exception text unless opted in.

    When ``settings.log_api_error_tracebacks`` is true, the exception message and
    its full traceback are logged. Otherwise only the exception's type name is
    logged.

    Args:
        settings: Application settings; ``log_api_error_tracebacks`` selects verbosity.
        exc: The exception to report. Its own traceback is formatted, so this
            helper does not depend on being called inside an ``except`` block.
        context: Short tag identifying the failing operation (e.g. ``"CREATE_MESSAGE_ERROR"``).
        request_id: Optional request id included in the log line when given.
    """
    if settings.log_api_error_tracebacks:
        if request_id is not None:
            logger.error("{} request_id={}: {}", context, request_id, exc)
        else:
            logger.error("{}: {}", context, exc)
        # Format ``exc``'s own traceback. ``traceback.format_exc()`` would report
        # whatever exception is currently being handled (or "NoneType: None" when
        # called outside an except block), which need not be ``exc``.
        logger.error(
            "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        )
        return
    if request_id is not None:
        logger.error(
            "{} request_id={} exc_type={}",
            context,
            request_id,
            type(exc).__name__,
        )
    else:
        logger.error("{} exc_type={}", context, type(exc).__name__)
def _require_non_empty_messages(messages: list[Any]) -> None:
if not messages:
raise InvalidRequestError("messages cannot be empty")
class ClaudeProxyService:
    """Coordinate request optimization, model routing, token count, and providers."""

    def __init__(
        self,
        settings: Settings,
        provider_getter: ProviderGetter,
        model_router: ModelRouter | None = None,
        token_counter: TokenCounter = get_token_count,
    ) -> None:
        """Wire the service to its collaborators.

        Args:
            settings: Application settings (web tools, fetch egress, error verbosity).
            provider_getter: Returns the provider instance for a provider id.
            model_router: Resolves requested model names to candidates; a
                ``ModelRouter(settings)`` is created when omitted.
            token_counter: Callable ``(messages, system, tools) -> int`` used to
                compute input token counts.
        """
        self._settings = settings
        self._provider_getter = provider_getter
        self._model_router = (
            model_router if model_router is not None else ModelRouter(settings)
        )
        self._token_counter = token_counter
        # NOTE(review): ``get_instance()`` presumably returns a shared singleton
        # tracker; confirm in core.session_tracker.
        self._session_tracker = SessionTracker.get_instance()

    def _get_session_id(self, request_data: MessagesRequest) -> str:
        """Return the session id used for per-session request tracking.

        Uses ``request_data.custom_id`` when the request carries a non-empty
        value. Otherwise a fresh random ``session_<12 hex>`` id is generated, so
        such requests are each tracked as their own session.
        """
        # ``hasattr`` alone is not enough: a field that exists but is ``None``
        # (or empty) would collapse all such requests into a single session.
        custom_id = getattr(request_data, "custom_id", None)
        if custom_id is not None and custom_id != "":
            return str(custom_id)
        return f"session_{uuid.uuid4().hex[:12]}"

    def create_message(self, request_data: MessagesRequest) -> object:
        """Create a message response or streaming response with optional failover.

        With a single resolved candidate, the request goes through
        ``_create_single_message``. With several (e.g. ``auto``), the SSE stream
        is produced by ``_stream_with_fallbacks``.

        Raises:
            ProviderError: Re-raised unchanged.
            HTTPException: For any other exception raised while preparing the
                response. Details are logged per ``log_api_error_tracebacks``.
        """
        try:
            _require_non_empty_messages(request_data.messages)
            candidates = self._model_router.resolve_candidates(request_data.model)
            if not candidates:
                raise InvalidRequestError(
                    f"No configured models available for '{request_data.model}'"
                )
            # For 'auto' requests with multiple candidates, wrap the stream in a
            # failover loop.
            # NOTE(review): this path skips the web-server-tool, optimization and
            # OpenAI-chat-upstream checks (and session tracking) done in
            # ``_create_single_message``; confirm that is intended.
            if len(candidates) > 1:
                return anthropic_sse_streaming_response(
                    self._stream_with_fallbacks(candidates, request_data)
                )
            # Standard path for single-model requests.
            return self._create_single_message(candidates[0], request_data)
        except ProviderError:
            raise
        except Exception as e:
            _log_unexpected_service_exception(
                self._settings, e, context="CREATE_MESSAGE_ERROR"
            )
            raise HTTPException(
                status_code=_http_status_for_unexpected_service_exception(e),
                detail=get_user_facing_error_message(e),
            ) from e

    def _create_single_message(
        self, resolved: ResolvedModel, request_data: MessagesRequest
    ) -> object:
        """Create a single message response from a resolved model.

        The request is deep-copied with its model rewritten to the provider
        model. It is then served, in order of precedence, by:
        1. the web-server-tool stream, when enabled and requested;
        2. ``try_optimizations``, when it returns a response;
        3. the provider's ``stream_response``.

        Raises:
            InvalidRequestError: When an OpenAI-chat upstream cannot serve the
                requested server tools.
        """
        # NOTE(review): ``ResolvedModel`` is not imported in this module. Annotations
        # are lazy (``from __future__ import annotations``), so only static type
        # checkers are affected.
        routed_request = request_data.model_copy(deep=True)
        routed_request.model = resolved.provider_model
        if resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
            tool_err = openai_chat_upstream_server_tool_error(
                routed_request,
                web_tools_enabled=self._settings.enable_web_server_tools,
            )
            if tool_err is not None:
                raise InvalidRequestError(tool_err)
        if self._settings.enable_web_server_tools and is_web_server_tool_request(
            routed_request
        ):
            # Must be imported in this method. A function-local import binds the
            # name only in the importing function, so importing it in
            # ``create_message`` left it undefined here (NameError). Kept lazy,
            # presumably to avoid an import cycle.
            from .web_tools.streaming import stream_web_server_tool_response

            input_tokens = self._token_counter(
                routed_request.messages, routed_request.system, routed_request.tools
            )
            logger.info("Optimization: Handling Anthropic web server tool")
            egress = WebFetchEgressPolicy(
                allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
                allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
            )
            return anthropic_sse_streaming_response(
                stream_web_server_tool_response(
                    routed_request,
                    input_tokens=input_tokens,
                    web_fetch_egress=egress,
                    verbose_client_errors=self._settings.log_api_error_tracebacks,
                ),
            )
        optimized = try_optimizations(routed_request, self._settings)
        if optimized is not None:
            return optimized
        provider = self._provider_getter(resolved.provider_id)
        provider.preflight_stream(
            routed_request,
            thinking_enabled=resolved.thinking_enabled,
        )
        session_id = self._get_session_id(request_data)
        self._session_tracker.track_request_sync(session_id, resolved.provider_id)
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        logger.info(
            "API_REQUEST: request_id={} model={} messages={}",
            request_id,
            routed_request.model,
            len(routed_request.messages),
        )
        input_tokens = self._token_counter(
            routed_request.messages, routed_request.system, routed_request.tools
        )
        return anthropic_sse_streaming_response(
            provider.stream_response(
                routed_request,
                input_tokens=input_tokens,
                request_id=request_id,
                thinking_enabled=resolved.thinking_enabled,
            ),
        )

    async def _stream_with_fallbacks(
        self, candidates: list[ResolvedModel], request_data: MessagesRequest
    ) -> AsyncIterator[str]:
        """Stream from the first candidate that succeeds, failing over on errors.

        A candidate that fails before emitting any event is skipped in favor of
        the next one. Once a candidate has emitted an event, the stream is
        committed to it: a later failure is re-raised instead of splicing another
        provider's stream onto the partial output.

        If every candidate fails, an SSE ``error`` event is yielded and then the
        last failure is raised. If there was no failure to raise,
        ``InvalidRequestError`` is raised instead.
        """
        last_exc: Exception | None = None
        for i, resolved in enumerate(candidates):
            yielded_any = False
            try:
                provider = self._provider_getter(resolved.provider_id)
                routed_request = request_data.model_copy(deep=True)
                routed_request.model = resolved.provider_model
                provider.preflight_stream(
                    routed_request,
                    thinking_enabled=resolved.thinking_enabled,
                )
                request_id = f"req_{uuid.uuid4().hex[:12]}"
                logger.info(
                    "API_REQUEST (auto fallback {}/{}): request_id={} provider={} model={}",
                    i + 1,
                    len(candidates),
                    request_id,
                    resolved.provider_id,
                    resolved.provider_model,
                )
                input_tokens = self._token_counter(
                    routed_request.messages, routed_request.system, routed_request.tools
                )
                async for event in provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                    thinking_enabled=resolved.thinking_enabled,
                ):
                    # Set before yielding: once an event reaches the client we are
                    # committed to this provider.
                    yielded_any = True
                    yield event
                return  # Success, exit the fallback loop.
            except Exception as e:
                if yielded_any:
                    # Falling back now would interleave two providers' streams.
                    logger.error(
                        "Provider '{}' failed mid-stream (exc_type={}); not falling back",
                        resolved.provider_id,
                        type(e).__name__,
                    )
                    raise
                if isinstance(e, (RateLimitError, OverloadedError)):
                    logger.warning(
                        "Provider '{}' is rate limited or overloaded ({}). Trying next candidate...",
                        resolved.provider_id,
                        e.status_code,
                    )
                elif self._settings.log_api_error_tracebacks:
                    logger.error(
                        "Provider '{}' failed with unexpected error: {}. Trying next candidate...",
                        resolved.provider_id,
                        e,
                    )
                else:
                    # Same policy as _log_unexpected_service_exception: no
                    # exception text in logs unless opted in.
                    logger.error(
                        "Provider '{}' failed with unexpected error (exc_type={}). Trying next candidate...",
                        resolved.provider_id,
                        type(e).__name__,
                    )
                last_exc = e
        # Client-facing text goes through the same helper used for HTTP errors
        # rather than echoing raw exception text.
        if last_exc is not None:
            err_msg = get_user_facing_error_message(last_exc)
        else:
            err_msg = "No candidates succeeded"
        yield format_sse_event(
            "error",
            {
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": f"All fallback candidates failed: {err_msg}",
                },
            },
        )
        if last_exc:
            raise last_exc
        raise InvalidRequestError("No candidates succeeded")

    def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
        """Count tokens for a request after applying configured model routing.

        Raises:
            ProviderError: Re-raised unchanged.
            HTTPException: For any other exception. Details are logged per
                ``log_api_error_tracebacks``.
        """
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        with logger.contextualize(request_id=request_id):
            try:
                _require_non_empty_messages(request_data.messages)
                routed = self._model_router.resolve_token_count_request(request_data)
                tokens = self._token_counter(
                    routed.request.messages, routed.request.system, routed.request.tools
                )
                logger.info(
                    "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
                    request_id,
                    routed.request.model,
                    len(routed.request.messages),
                    tokens,
                )
                return TokenCountResponse(input_tokens=tokens)
            except ProviderError:
                raise
            except Exception as e:
                _log_unexpected_service_exception(
                    self._settings,
                    e,
                    context="COUNT_TOKENS_ERROR",
                    request_id=request_id,
                )
                raise HTTPException(
                    status_code=_http_status_for_unexpected_service_exception(e),
                    detail=get_user_facing_error_message(e),
                ) from e