claude-code-proxy / api /services.py
Yash030's picture
Add routing debug log to trace provider resolution.
65739aa
"""Application services for the Claude-compatible API."""
from __future__ import annotations
import traceback
import uuid
from collections.abc import AsyncIterator, Callable
from typing import Any
from fastapi import HTTPException, Request
from fastapi.responses import StreamingResponse
from loguru import logger
from config.settings import Settings, get_settings
from core.anthropic import get_token_count, get_user_facing_error_message
from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event
from core.session_tracker import SessionTracker
from providers.base import BaseProvider
from providers.exceptions import (
InvalidRequestError,
OverloadedError,
ProviderError,
RateLimitError,
)
from .model_router import ModelRouter, ResolvedModel
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
from .web_tools.egress import WebFetchEgressPolicy
from .web_tools.request import (
is_web_server_tool_request,
openai_chat_upstream_server_tool_error,
)
# Token counter signature: called as (messages, system, tools) and returns a count.
TokenCounter = Callable[
    [list[Any], str | list[Any] | None, list[Any] | None],
    int,
]
# Lookup callable: given a provider id, returns the provider to use.
ProviderGetter = Callable[[str], BaseProvider]
# Provider ids served via ``/chat/completions`` with Anthropic-to-OpenAI
# conversion rather than the native Messages API.
_OPENAI_CHAT_UPSTREAM_IDS = frozenset(
    ("nvidia_nim", "groq", "cerebras", "silicon"),
)
def anthropic_sse_streaming_response(
    body: AsyncIterator[str],
) -> StreamingResponse:
    """Wrap an async iterator of SSE text in a ``text/event-stream`` response.

    The response is sent with ``ANTHROPIC_SSE_RESPONSE_HEADERS``.
    """
    response_options = {
        "media_type": "text/event-stream",
        "headers": ANTHROPIC_SSE_RESPONSE_HEADERS,
    }
    return StreamingResponse(body, **response_options)
def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
    """Map any uncaught non-provider failure to a fixed HTTP status.

    The exception argument is accepted but not inspected: every such failure
    is reported as 500 so clients see a stable contract.
    """
    internal_server_error = 500
    return internal_server_error
def _log_unexpected_service_exception(
    settings: Settings,
    exc: BaseException,
    *,
    context: str,
    request_id: str | None = None,
) -> None:
    """Log a service-layer failure without echoing exception text unless opted in.

    Args:
        settings: Read for ``log_api_error_tracebacks``; when truthy the
            exception message and traceback are logged, otherwise only the
            exception type name is.
        exc: The exception being reported.
        context: Short label placed at the start of the log line.
        request_id: Optional request identifier added to the log line.
    """
    if settings.log_api_error_tracebacks:
        if request_id is not None:
            logger.error("{} request_id={}: {}", context, request_id, exc)
        else:
            logger.error("{}: {}", context, exc)
        # Format the traceback from ``exc`` itself rather than from the
        # exception currently being handled, so the output is correct even
        # when this helper is called outside an ``except`` block.
        logger.error(
            "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        )
        return

    # Quiet mode: the type name only, never the exception message.
    exc_type_name = type(exc).__name__
    if request_id is not None:
        logger.error(
            "{} request_id={} exc_type={}",
            context,
            request_id,
            exc_type_name,
        )
    else:
        logger.error("{} exc_type={}", context, exc_type_name)
def _require_non_empty_messages(messages: list[Any]) -> None:
    """Raise ``InvalidRequestError`` when ``messages`` is falsy (e.g. empty)."""
    if messages:
        return
    raise InvalidRequestError("messages cannot be empty")
def _get_client_ip(request: Request) -> str | None:
    """Return the caller's IP as reported by gateway headers, if any.

    Headers are consulted in this order: ``X-Forwarded-For`` (first
    comma-separated entry, stripped), ``X-Real-IP``, ``X-Client-IP``. If none
    yields a value but ``Via`` is present, the socket peer address
    (``request.client.host``, i.e. the gateway/proxy) is returned.

    Returns:
        The first non-empty value found, or ``None`` when no gateway header is
        present (direct connection) or when ``Via`` is set but the peer
        address is unavailable.
    """
    headers = request.headers

    forwarded = headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        # An empty first entry (e.g. ", 10.0.0.1") carries no information;
        # fall through to the remaining headers instead of returning "".
        if first_hop:
            return first_hop

    for header_name in ("X-Real-IP", "X-Client-IP"):
        value = headers.get(header_name)
        if value:
            return value

    if headers.get("Via"):
        # ``request.client`` may be None when the server has no peer info;
        # guard it so this lookup cannot raise AttributeError.
        peer = request.client
        return peer.host if peer is not None else None  # Gateway/proxy IP

    return None  # Direct connection
def _get_session_id(request: Request) -> str:
    """Derive a session key for the request.

    The ``X-Session-ID`` header wins when present (Claude Code sends it when
    started with ``--session-id <uuid>``). Otherwise the key is
    ``gateway_<ip>`` when a client IP can be determined, or ``direct``.
    """
    explicit_id = request.headers.get("X-Session-ID")
    if not explicit_id:
        client_ip = _get_client_ip(request)
        if client_ip:
            return f"gateway_{client_ip}"
        return "direct"
    return explicit_id
class ClaudeProxyService:
    """Coordinate request optimization, model routing, and providers."""

    # Substrings (matched against the lower-cased exception text) that mark a
    # failure as transient during auto-fallback, so it counts against the
    # candidate's health.
    _TRANSIENT_ERROR_KEYWORDS = (
        "timeout",
        "connection",
        "refused",
        "reset",
        "unavailable",
        "service",
    )

    def __init__(
        self,
        settings: Settings,
        provider_getter: ProviderGetter,
        model_router: ModelRouter | None = None,
        token_counter: TokenCounter = get_token_count,
    ):
        """Store collaborators and attach the shared session tracker.

        Args:
            settings: Application settings read for feature flags and logging.
            provider_getter: Callable returning the provider for a provider id.
            model_router: Router used to resolve models; ``ModelRouter(settings)``
                is created when omitted.
            token_counter: Callable used to count input tokens.
        """
        self._settings = settings
        self._provider_getter = provider_getter
        self._model_router = model_router or ModelRouter(settings)
        self._token_counter = token_counter
        # NOTE(review): retention is read from get_settings() rather than the
        # injected ``settings`` - confirm this is intentional.
        settings_local = get_settings()
        self._session_tracker = SessionTracker.get_instance(
            retention_seconds=settings_local.session_retention_minutes * 60
        )

    def create_message(self, request: Request, request_data: MessagesRequest) -> object:
        """Create a message response or streaming response with optional failover.

        Args:
            request: Incoming HTTP request (used for session identification).
            request_data: Parsed Messages API request body.

        Returns:
            A streaming response; with more than one resolved candidate the
            stream is produced by the fallback loop, otherwise by the
            single-model path.

        Raises:
            ProviderError: Re-raised unchanged (includes ``InvalidRequestError``
                for empty messages or no resolvable candidates).
            HTTPException: For any other exception, after it has been logged.
        """
        try:
            _require_non_empty_messages(request_data.messages)
            candidates = self._model_router.resolve_candidates(request_data.model)
            if not candidates:
                raise InvalidRequestError(
                    f"No configured models available for '{request_data.model}'"
                )

            # Debug log what we're routing to
            logger.info(
                "REQUEST_MODEL_ROUTING: requested={} resolved_provider={} resolved_model={}",
                request_data.model,
                candidates[0].provider_id,
                candidates[0].provider_model,
            )

            # For 'auto' requests with multiple candidates, we wrap the stream in a failover loop.
            # NOTE(review): this path streams straight from providers; the
            # web-tool and optimization handling of _create_single_message is
            # not applied here.
            if len(candidates) > 1:
                return anthropic_sse_streaming_response(
                    self._stream_with_fallbacks(request, candidates, request_data)
                )

            # Standard path for single-model requests
            return self._create_single_message(request, candidates[0], request_data)
        except ProviderError:
            raise
        except Exception as e:
            _log_unexpected_service_exception(
                self._settings, e, context="CREATE_MESSAGE_ERROR"
            )
            raise HTTPException(
                status_code=_http_status_for_unexpected_service_exception(e),
                detail=get_user_facing_error_message(e),
            ) from e

    def _create_single_message(
        self, request: Request, resolved: ResolvedModel, request_data: MessagesRequest
    ) -> object:
        """Create a single message response from a resolved model.

        Works on a deep copy of ``request_data`` whose ``model`` is replaced by
        the provider's model name. In order, the request may be answered by the
        web server tool stream, by an optimization handler, or by the provider.

        Raises:
            InvalidRequestError: When an OpenAI-chat upstream is targeted and
                ``openai_chat_upstream_server_tool_error`` reports a problem.
        """
        routed_request = request_data.model_copy(deep=True)
        routed_request.model = resolved.provider_model

        if resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
            tool_err = openai_chat_upstream_server_tool_error(
                routed_request,
                web_tools_enabled=self._settings.enable_web_server_tools,
            )
            if tool_err is not None:
                raise InvalidRequestError(tool_err)

        if self._settings.enable_web_server_tools and is_web_server_tool_request(
            routed_request
        ):
            from .web_tools.streaming import stream_web_server_tool_response

            input_tokens = self._token_counter(
                routed_request.messages, routed_request.system, routed_request.tools
            )
            logger.info("Optimization: Handling Anthropic web server tool")
            egress = WebFetchEgressPolicy(
                allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
                allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
            )
            return anthropic_sse_streaming_response(
                stream_web_server_tool_response(
                    routed_request,
                    input_tokens=input_tokens,
                    web_fetch_egress=egress,
                    verbose_client_errors=self._settings.log_api_error_tracebacks,
                ),
            )

        optimized = try_optimizations(routed_request, self._settings)
        if optimized is not None:
            return optimized

        provider = self._provider_getter(resolved.provider_id)
        provider.preflight_stream(
            routed_request,
            thinking_enabled=resolved.thinking_enabled,
        )

        session_id = _get_session_id(request)
        self._session_tracker.track_request_sync(session_id, resolved.provider_id)

        request_id = f"req_{uuid.uuid4().hex[:12]}"
        logger.info(
            "API_REQUEST: request_id={} model={} messages={}",
            request_id,
            routed_request.model,
            len(routed_request.messages),
        )
        input_tokens = self._token_counter(
            routed_request.messages, routed_request.system, routed_request.tools
        )
        return anthropic_sse_streaming_response(
            provider.stream_response(
                routed_request,
                input_tokens=input_tokens,
                request_id=request_id,
                thinking_enabled=resolved.thinking_enabled,
            ),
        )

    async def _stream_with_fallbacks(
        self,
        request: Request,
        candidates: list[ResolvedModel],
        request_data: MessagesRequest,
    ) -> AsyncIterator[str]:
        """Iterate through candidates until one succeeds or all fail.

        Candidates whose limiter reports them blocked (except provider id
        ``"zen"``) or unhealthy are skipped. A failure before any event has
        been yielded moves on to the next candidate. Once an event has been
        yielded the stream is committed to that provider, so a later failure
        is re-raised instead of restarting the response on another candidate.

        Yields:
            SSE event strings from the chosen provider, or a single ``error``
            event when every candidate was skipped or failed.

        Raises:
            Exception: The last recorded failure when no candidate succeeded,
                or the original exception on a mid-stream failure.
            InvalidRequestError: When no candidate succeeded and no failure
                was recorded (e.g. all were skipped as blocked).
        """
        last_exc: Exception | None = None
        total = len(candidates)

        for i, resolved in enumerate(candidates):
            # Reset per candidate so the handler below never touches an
            # unbound name or the previous candidate's limiter.
            limiter = None
            # Becomes True as soon as one event has been handed to the client.
            committed = False
            try:
                # Pre-check: skip candidates that are currently rate limited or unhealthy
                from providers.rate_limit import GlobalRateLimiter

                limiter = GlobalRateLimiter.get_scoped_instance(resolved.provider_id)
                if limiter.is_blocked() and resolved.provider_id != "zen":
                    # Silently skip - no failure penalty for temporary rate limit
                    logger.debug(
                        "Skipping blocked provider '{}' (no penalty)",
                        resolved.provider_id,
                    )
                    continue

                # Check model health (recent failures)
                if not limiter.is_healthy(resolved.provider_model_ref):
                    logger.warning(
                        "Provider '{}' has recent failures, skipping to next candidate...",
                        resolved.provider_model_ref,
                    )
                    last_exc = Exception("Recent failures")
                    continue

                provider = self._provider_getter(resolved.provider_id)
                routed_request = request_data.model_copy(deep=True)
                routed_request.model = resolved.provider_model
                provider.preflight_stream(
                    routed_request,
                    thinking_enabled=resolved.thinking_enabled,
                )

                session_id = _get_session_id(request)
                self._session_tracker.track_request_sync(
                    session_id, resolved.provider_id
                )

                request_id = f"req_{uuid.uuid4().hex[:12]}"
                logger.info(
                    "API_REQUEST (auto fallback {}/{}): request_id={} provider={} model={}",
                    i + 1,
                    total,
                    request_id,
                    resolved.provider_id,
                    resolved.provider_model,
                )
                input_tokens = self._token_counter(
                    routed_request.messages, routed_request.system, routed_request.tools
                )

                # Attempt to stream from this provider.
                async for event in provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                    thinking_enabled=resolved.thinking_enabled,
                ):
                    # CRITICAL: once even one event is yielded we are committed
                    # to this provider and must not fall back mid-stream.
                    committed = True
                    yield event
                return  # Success, exit the fallback loop.
            except Exception as e:
                rate_limited = isinstance(e, (RateLimitError, OverloadedError))
                timed_out = not rate_limited and isinstance(e, TimeoutError)
                transient = False
                if not (rate_limited or timed_out):
                    # Check if it's a transient error that should count as a failure
                    error_str = str(e).lower()
                    transient = any(
                        kw in error_str for kw in self._TRANSIENT_ERROR_KEYWORDS
                    )

                # Only these failure classes count against the candidate's
                # health; the limiter may be missing if acquiring it failed.
                if (rate_limited or timed_out or transient) and limiter is not None:
                    limiter.record_failure(resolved.provider_model_ref)

                if committed:
                    # Part of the response is already on the wire; starting
                    # another candidate would corrupt the stream.
                    logger.error(
                        "Provider '{}' failed mid-stream ({}); not falling back after events were sent.",
                        resolved.provider_id,
                        type(e).__name__,
                    )
                    raise

                if rate_limited:
                    logger.warning(
                        "Provider '{}' is rate limited or overloaded ({}). Trying next candidate...",
                        resolved.provider_id,
                        e.status_code,
                    )
                elif timed_out:
                    # Timeout = slow model, try next candidate for faster response
                    logger.warning(
                        "Provider '{}' timed out ({}). Trying next candidate...",
                        resolved.provider_id,
                        type(e).__name__,
                    )
                elif transient:
                    logger.warning(
                        "Provider '{}' failed with transient error ({}): {}. Trying next candidate...",
                        resolved.provider_id,
                        type(e).__name__,
                        e,
                    )
                else:
                    logger.error(
                        "Provider '{}' failed with unexpected error: {}. Trying next candidate...",
                        resolved.provider_id,
                        e,
                    )
                last_exc = e
                continue

        # Every candidate was skipped or failed before producing output:
        # tell the client via an SSE error event, then surface the failure.
        err_msg = str(last_exc) if last_exc is not None else "No candidates succeeded"
        yield format_sse_event(
            "error",
            {
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": f"All fallback candidates failed: {err_msg}",
                },
            },
        )
        if last_exc is not None:
            raise last_exc
        raise InvalidRequestError("No candidates succeeded")

    def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
        """Count tokens for a request after applying configured model routing.

        Raises:
            ProviderError: Re-raised unchanged (includes ``InvalidRequestError``
                for empty messages).
            HTTPException: For any other exception, after it has been logged.
        """
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        with logger.contextualize(request_id=request_id):
            try:
                _require_non_empty_messages(request_data.messages)
                routed = self._model_router.resolve_token_count_request(request_data)
                tokens = self._token_counter(
                    routed.request.messages, routed.request.system, routed.request.tools
                )
                logger.info(
                    "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
                    request_id,
                    routed.request.model,
                    len(routed.request.messages),
                    tokens,
                )
                return TokenCountResponse(input_tokens=tokens)
            except ProviderError:
                raise
            except Exception as e:
                _log_unexpected_service_exception(
                    self._settings,
                    e,
                    context="COUNT_TOKENS_ERROR",
                    request_id=request_id,
                )
                raise HTTPException(
                    status_code=_http_status_for_unexpected_service_exception(e),
                    detail=get_user_facing_error_message(e),
                ) from e