| """FastAPI route handlers.""" | |
| from fastapi import APIRouter, Depends, HTTPException, Request, Response | |
| from fastapi.responses import HTMLResponse | |
| from loguru import logger | |
| from starlette.templating import Jinja2Templates | |
| from config.settings import Settings | |
| from core.anthropic import get_token_count | |
| from providers.nvidia_nim import metrics as nvidia_nim_metrics | |
| from providers.registry import ProviderRegistry | |
| from . import dependencies | |
| from .dependencies import get_settings, require_api_key | |
| from .gateway_model_ids import gateway_model_id, no_thinking_gateway_model_id | |
| from .models.anthropic import MessagesRequest, TokenCountRequest | |
| from .models.responses import ModelResponse, ModelsListResponse | |
| from .services import ClaudeProxyService | |
| router = APIRouter() | |
| templates = Jinja2Templates(directory="templates") | |
| DISCOVERED_MODEL_CREATED_AT = "1970-01-01T00:00:00Z" | |

# The proxy advertises a curated set of provider-backed models rather than
# the upstream Claude catalog, so clients only see these options.
REQUESTED_PROVIDER_MODELS = [
    # Zen/OpenCode free models
    "zen/minimax-m2.5-free",
    "zen/big-pickle",
    "zen/ring-2.6-1t-free",
    "zen/nemotron-3-super-free",
    # NVIDIA NIM models (top 5)
    "nvidia_nim/stepfun-ai/step-3.5-flash",
    "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
    "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
    "nvidia_nim/z-ai/glm4.7",
    "nvidia_nim/minimaxai/minimax-m2.7",
    # Cerebras models (the current key only has access to llama3.1-8b):
    # qwen-3-235b-a22b-instruct-2507 exists but is rate-limited;
    # zai-glm-4.7 and gpt-oss-120b are not accessible with the current key.
    "cerebras/llama3.1-8b",
    # Silicon Flow models (top 5 for the free tier)
    # DeepSeek-V3 - strong MoE model
    "silicon/deepseek-ai/DeepSeek-V3",
    # Qwen3-Coder-30B-A3B - coding-specialized
    "silicon/Qwen/Qwen3-Coder-30B-A3B-Instruct",
    # Qwen3.6-35B-A3B - multimodal, 262K context
    "silicon/Qwen/Qwen3.6-35B-A3B",
    # Qwen2.5-72B - strong general purpose, 128K context
    "silicon/Qwen/Qwen2.5-72B-Instruct",
    # Qwen3-32B - reasoning model
    "silicon/Qwen/Qwen3-32B",
    # Groq models (ultra-fast inference)
    "groq/llama-3.3-70b-versatile",
    "groq/llama-3.1-8b-instant",
    "groq/qwen3-32b",
]
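# Each entry is a "<provider>/<model>" ref. Bare refs without a provider
# prefix are assumed to belong to the NVIDIA NIM provider when the model
# list is built (see _build_models_list_response below).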


def get_proxy_service(
    request: Request,
    settings: Settings = Depends(get_settings),
) -> ClaudeProxyService:
    """Build the request service for route handlers."""
    return ClaudeProxyService(
        settings,
        provider_getter=lambda provider_type: dependencies.resolve_provider(
            provider_type, app=request.app, settings=settings
        ),
        token_counter=get_token_count,
    )


def _probe_response(allow: str) -> Response:
    """Return an empty success response for compatibility probes."""
    return Response(status_code=204, headers={"Allow": allow})
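# The probe handlers below return this for HEAD/OPTIONS compatibility probes;
# the Allow header advertises the methods the real endpoint accepts.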


def _discovered_model_response(model_id: str, *, display_name: str) -> ModelResponse:
    return ModelResponse(
        id=model_id,
        display_name=display_name,
        created_at=DISCOVERED_MODEL_CREATED_AT,
    )


def _append_unique_model(
    models: list[ModelResponse], seen: set[str], model: ModelResponse
) -> None:
    if model.id in seen:
        return
    seen.add(model.id)
    models.append(model)


def _append_provider_model_variants(
    models: list[ModelResponse],
    seen: set[str],
    provider_model_ref: str,
    *,
    supports_thinking: bool | None = None,
) -> None:
    # Unknown thinking support (None) is treated optimistically: only an
    # explicit False suppresses the thinking-capable variant.
    if supports_thinking is not False:
        _append_unique_model(
            models,
            seen,
            _discovered_model_response(
                gateway_model_id(provider_model_ref),
                display_name=provider_model_ref,
            ),
        )
    _append_unique_model(
        models,
        seen,
        _discovered_model_response(
            no_thinking_gateway_model_id(provider_model_ref),
            display_name=f"{provider_model_ref} (no thinking)",
        ),
    )
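# For a ref like "groq/qwen3-32b" this yields up to two catalog entries: a
# thinking-capable id plus a "... (no thinking)" variant. The exact id
# strings come from gateway_model_ids and are not reproduced here.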


def _build_models_list_response(
    settings: Settings, provider_registry: ProviderRegistry | None
) -> ModelsListResponse:
    models: list[ModelResponse] = []
    seen: set[str] = set()
    # Advertise only the requested provider models (no Claude models, no
    # registry auto-discovery). Each ref is added with both thinking and
    # no-thinking variants.
    for provider_ref in REQUESTED_PROVIDER_MODELS:
        # If the ref already contains a provider prefix, use it as-is;
        # otherwise assume it belongs to the NVIDIA NIM provider.
        ref = provider_ref if "/" in provider_ref else f"nvidia_nim/{provider_ref}"
        supports_thinking = None
        if provider_registry is not None:
            # Registry lookups take the provider name and the bare model id
            # separately, so split the prefixed ref back apart.
            provider, model_id = (
                ref.split("/", 1) if "/" in ref else ("nvidia_nim", ref)
            )
            supports_thinking = provider_registry.cached_model_supports_thinking(
                provider, model_id
            )
        _append_provider_model_variants(
            models, seen, ref, supports_thinking=supports_thinking
        )
    # Add a virtual `auto` model that maps to the configured MODEL and enables
    # automatic fallback behavior when used by clients.
    auto_model = ModelResponse(
        id=gateway_model_id("auto"),
        display_name="auto (use configured fallbacks)",
        created_at=DISCOVERED_MODEL_CREATED_AT,
    )
    _append_unique_model(models, seen, auto_model)
    # Filter out any residual Claude-branded models so the proxy advertises
    # only the provider-backed models requested by the user.
    filtered = [
        m
        for m in models
        if "claude" not in (m.id or "").lower()
        and "claude" not in (m.display_name or "").lower()
    ]
    # Ensure the `auto` model remains available even if the filter removed it.
    if not any(m.id == auto_model.id for m in filtered):
        filtered.append(auto_model)
    return ModelsListResponse(
        data=filtered,
        first_id=filtered[0].id if filtered else None,
        has_more=False,
        last_id=filtered[-1].id if filtered else None,
    )
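# The resulting catalog is therefore: the requested refs (each as up to two
# variants, in list order), followed by the virtual `auto` entry, with any
# Claude-branded ids filtered out.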


# =============================================================================
# Routes
# =============================================================================


async def create_message(
    request: Request,
    request_data: MessagesRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Create a message (always streaming)."""
    return service.create_message(request, request_data)


async def probe_messages(_auth=Depends(require_api_key)):
    """Respond to Claude compatibility probes for the messages endpoint."""
    return _probe_response("POST, HEAD, OPTIONS")


async def count_tokens(
    request_data: TokenCountRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Count tokens for a request."""
    return service.count_tokens(request_data)


async def probe_count_tokens(_auth=Depends(require_api_key)):
    """Respond to Claude compatibility probes for the token count endpoint."""
    return _probe_response("POST, HEAD, OPTIONS")


async def root(request: Request, _auth=Depends(require_api_key)):
    """Root endpoint - displays the admin dashboard."""
    # Imported lazily; importing .admin at module load would likely create a
    # circular import with this module.
    from .admin import _get_admin_data

    data = _get_admin_data()
    return templates.TemplateResponse("admin.html", {"request": request, **data})


async def probe_root(_auth=Depends(require_api_key)):
    """Respond to compatibility probes for the root endpoint."""
    return _probe_response("GET, HEAD, OPTIONS")


async def health():
    """Health check endpoint."""
    return {"status": "healthy"}


async def probe_health():
    """Respond to compatibility probes for the health endpoint."""
    return _probe_response("GET, HEAD, OPTIONS")


async def list_models(
    request: Request,
    settings: Settings = Depends(get_settings),
    _auth=Depends(require_api_key),
):
    """List the model ids this proxy advertises to Claude-compatible clients."""
    registry = getattr(request.app.state, "provider_registry", None)
    provider_registry = registry if isinstance(registry, ProviderRegistry) else None
    return _build_models_list_response(settings, provider_registry)


async def stop_cli(request: Request, _auth=Depends(require_api_key)):
    """Stop all CLI sessions and pending tasks."""
    handler = getattr(request.app.state, "message_handler", None)
    if not handler:
        # Fallback if the messaging system was never initialized.
        cli_manager = getattr(request.app.state, "cli_manager", None)
        if cli_manager:
            await cli_manager.stop_all()
            logger.info("STOP_CLI: source=cli_manager cancelled_count=N/A")
            return {"status": "stopped", "source": "cli_manager"}
        raise HTTPException(status_code=503, detail="Messaging system not initialized")
    count = await handler.stop_all_tasks()
    logger.info("STOP_CLI: source=handler cancelled_count={}", count)
    return {"status": "stopped", "cancelled_count": count}


async def admin_fallbacks(_auth=Depends(require_api_key)):
    """Admin endpoint exposing NVIDIA NIM fallback metrics.

    Protected by the same API key as the other endpoints.
    """
    try:
        data = nvidia_nim_metrics.snapshot()
    except Exception as e:
        logger.warning("ADMIN_FALLBACKS: failed to read metrics: {}", e)
        raise HTTPException(status_code=500, detail="failed to read metrics") from e
    return {"provider": "nvidia_nim", "fallbacks": data}