"""FastAPI route handlers."""
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import HTMLResponse
from loguru import logger
from starlette.templating import Jinja2Templates
from config.settings import Settings
from core.anthropic import get_token_count
from providers.nvidia_nim import metrics as nvidia_nim_metrics
from providers.registry import ProviderRegistry
from . import dependencies
from .dependencies import get_settings, require_api_key
from .gateway_model_ids import gateway_model_id, no_thinking_gateway_model_id
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import ModelResponse, ModelsListResponse
from .services import ClaudeProxyService
# Shared router and template engine for all handlers in this module.
router = APIRouter()
templates = Jinja2Templates(directory="templates")
# Sentinel timestamp (Unix epoch) advertised for runtime-discovered models,
# which have no real creation date.
DISCOVERED_MODEL_CREATED_AT = "1970-01-01T00:00:00Z"
# The proxy advertises a curated set of provider-backed models. Replace
# the previous hardcoded Claude model list with the requested NVIDIA-
# compatible models so clients only see those options.
# Each entry is a "provider/model" ref; entries without a "/" are assumed
# to belong to the NVIDIA NIM provider (see _build_models_list_response).
REQUESTED_PROVIDER_MODELS = [
    # Zen/OpenCode free models
    "zen/minimax-m2.5-free",
    "zen/big-pickle",
    "zen/ring-2.6-1t-free",
    "zen/nemotron-3-super-free",
    # NVIDIA NIM models (top 5)
    "nvidia_nim/stepfun-ai/step-3.5-flash",
    "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
    "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
    "nvidia_nim/z-ai/glm4.7",
    "nvidia_nim/minimaxai/minimax-m2.7",
    # Cerebras models (key only has access to llama3.1-8b currently)
    # qwen-3-235b-a22b-instruct-2507 exists but is rate-limited
    # zai-glm-4.7 and gpt-oss-120b are not accessible with current key
    "cerebras/llama3.1-8b",
    # Silicon Flow models (top 5 for free tier)
    # DeepSeek-V3 - strong MoE model
    "silicon/deepseek-ai/DeepSeek-V3",
    # Qwen3-Coder-30B-A3B - coding specialized
    "silicon/Qwen/Qwen3-Coder-30B-A3B-Instruct",
    # Qwen3.6-35B-A3B - multimodal, 262K context
    "silicon/Qwen/Qwen3.6-35B-A3B",
    # Qwen2.5-72B - strong general purpose, 128K context
    "silicon/Qwen/Qwen2.5-72B-Instruct",
    # Qwen3-32B - reasoning model
    "silicon/Qwen/Qwen3-32B",
    # Groq models (ultra fast inference)
    "groq/llama-3.3-70b-versatile",
    "groq/llama-3.1-8b-instant",
    "groq/qwen3-32b",
]
def get_proxy_service(
    request: Request,
    settings: Settings = Depends(get_settings),
) -> ClaudeProxyService:
    """Construct the per-request ClaudeProxyService for route handlers.

    Providers are resolved lazily through the dependencies module so the
    service looks them up against the live app state on each call.
    """

    def _resolve(provider_type):
        return dependencies.resolve_provider(
            provider_type, app=request.app, settings=settings
        )

    return ClaudeProxyService(
        settings,
        provider_getter=_resolve,
        token_counter=get_token_count,
    )
def _probe_response(allow: str) -> Response:
    """Build the empty 204 reply used for HEAD/OPTIONS compatibility probes."""
    headers = {"Allow": allow}
    return Response(status_code=204, headers=headers)
def _discovered_model_response(model_id: str, *, display_name: str) -> ModelResponse:
    """Wrap a discovered model id in a ModelResponse with the sentinel timestamp."""
    response = ModelResponse(
        id=model_id,
        display_name=display_name,
        created_at=DISCOVERED_MODEL_CREATED_AT,
    )
    return response
def _append_unique_model(
models: list[ModelResponse], seen: set[str], model: ModelResponse
) -> None:
if model.id in seen:
return
seen.add(model.id)
models.append(model)
def _append_provider_model_variants(
    models: list[ModelResponse],
    seen: set[str],
    provider_model_ref: str,
    *,
    supports_thinking: bool | None = None,
) -> None:
    """Add the thinking and no-thinking gateway variants for one provider ref.

    The thinking variant is skipped only when *supports_thinking* is
    explicitly ``False``; ``None`` (unknown) is treated optimistically.
    """
    if supports_thinking is not False:
        thinking_variant = _discovered_model_response(
            gateway_model_id(provider_model_ref),
            display_name=provider_model_ref,
        )
        _append_unique_model(models, seen, thinking_variant)
    no_thinking_variant = _discovered_model_response(
        no_thinking_gateway_model_id(provider_model_ref),
        display_name=f"{provider_model_ref} (no thinking)",
    )
    _append_unique_model(models, seen, no_thinking_variant)
def _build_models_list_response(
    settings: Settings, provider_registry: ProviderRegistry | None
) -> ModelsListResponse:
    """Build the /v1/models payload from the curated provider model list.

    Args:
        settings: Application settings (currently unused by this function;
            kept in the signature for caller compatibility).
        provider_registry: Optional registry consulted for cached
            thinking-support metadata; ``None`` skips that lookup.

    Returns:
        A ModelsListResponse containing the curated provider-backed models
        (thinking and no-thinking variants) plus the virtual ``auto`` model,
        with any Claude-branded entries filtered out.
    """
    models: list[ModelResponse] = []
    seen: set[str] = set()

    def _auto_model() -> ModelResponse:
        # Virtual model that maps to the configured MODEL and enables
        # automatic fallback behavior when used by clients.
        return ModelResponse(
            id=gateway_model_id("auto"),
            display_name="auto (use configured fallbacks)",
            created_at=DISCOVERED_MODEL_CREATED_AT,
        )

    # Advertise only the requested provider models (no Claude models, no
    # registry auto-discovery). Each ref gets both variants.
    for provider_ref in REQUESTED_PROVIDER_MODELS:
        # Refs without an explicit provider prefix default to NVIDIA NIM.
        ref = provider_ref if "/" in provider_ref else f"nvidia_nim/{provider_ref}"
        supports_thinking = None
        if provider_registry is not None:
            # After normalization `ref` always contains "/", so a plain split
            # is safe (the previous no-slash fallback branch was dead code).
            provider, model_id = ref.split("/", 1)
            supports_thinking = provider_registry.cached_model_supports_thinking(
                provider, model_id
            )
        _append_provider_model_variants(
            models, seen, ref, supports_thinking=supports_thinking
        )

    _append_unique_model(models, seen, _auto_model())

    # Filter out any residual Claude-branded models so the proxy advertises
    # only the provider-backed models requested by the user.
    filtered = [
        m
        for m in models
        if "claude" not in (m.id or "").lower()
        and "claude" not in (m.display_name or "").lower()
    ]
    # Ensure the `auto` model remains available even if filtering removed it.
    if not any(m.id == gateway_model_id("auto") for m in filtered):
        filtered.append(_auto_model())

    return ModelsListResponse(
        data=filtered,
        first_id=filtered[0].id if filtered else None,
        has_more=False,
        last_id=filtered[-1].id if filtered else None,
    )
# =============================================================================
# Routes
# =============================================================================
@router.post("/v1/messages")
async def create_message(
    request: Request,
    request_data: MessagesRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Proxy a message-creation request (responses are always streamed)."""
    response = service.create_message(request, request_data)
    return response
@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
async def probe_messages(_auth=Depends(require_api_key)):
    """Answer HEAD/OPTIONS compatibility probes against the messages endpoint."""
    allowed = "POST, HEAD, OPTIONS"
    return _probe_response(allowed)
@router.post("/v1/messages/count_tokens")
async def count_tokens(
    request_data: TokenCountRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Return the token count for the supplied request payload."""
    result = service.count_tokens(request_data)
    return result
@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
async def probe_count_tokens(_auth=Depends(require_api_key)):
    """Answer HEAD/OPTIONS compatibility probes against the token-count endpoint."""
    allowed = "POST, HEAD, OPTIONS"
    return _probe_response(allowed)
@router.get("/", response_class=HTMLResponse)
async def root(request: Request, _auth=Depends(require_api_key)):
    """Render the admin dashboard at the root path."""
    # Local import; presumably avoids a circular dependency with .admin —
    # confirm before moving to module level.
    from .admin import _get_admin_data

    context = {"request": request, **_get_admin_data()}
    return templates.TemplateResponse("admin.html", context)
@router.api_route("/", methods=["HEAD", "OPTIONS"])
async def probe_root(_auth=Depends(require_api_key)):
    """Answer HEAD/OPTIONS compatibility probes against the root endpoint."""
    allowed = "GET, HEAD, OPTIONS"
    return _probe_response(allowed)
@router.get("/health")
async def health():
    """Liveness probe; always reports healthy (no auth required)."""
    status = {"status": "healthy"}
    return status
@router.api_route("/health", methods=["HEAD", "OPTIONS"])
async def probe_health():
    """Answer HEAD/OPTIONS compatibility probes against the health endpoint."""
    allowed = "GET, HEAD, OPTIONS"
    return _probe_response(allowed)
@router.get("/v1/models", response_model=ModelsListResponse)
async def list_models(
    request: Request,
    settings: Settings = Depends(get_settings),
    _auth=Depends(require_api_key),
):
    """Advertise the curated model ids to Claude-compatible clients."""
    registry = getattr(request.app.state, "provider_registry", None)
    # Only pass a genuine ProviderRegistry through; anything else is ignored.
    if not isinstance(registry, ProviderRegistry):
        registry = None
    return _build_models_list_response(settings, registry)
@router.post("/stop")
async def stop_cli(request: Request, _auth=Depends(require_api_key)):
    """Stop all CLI sessions and cancel pending tasks."""
    handler = getattr(request.app.state, "message_handler", None)
    if handler:
        cancelled = await handler.stop_all_tasks()
        logger.info("STOP_CLI: source=handler cancelled_count={}", cancelled)
        return {"status": "stopped", "cancelled_count": cancelled}
    # Messaging layer unavailable: fall back to stopping the CLI manager.
    cli_manager = getattr(request.app.state, "cli_manager", None)
    if cli_manager:
        await cli_manager.stop_all()
        logger.info("STOP_CLI: source=cli_manager cancelled_count=N/A")
        return {"status": "stopped", "source": "cli_manager"}
    raise HTTPException(status_code=503, detail="Messaging system not initialized")
@router.get("/admin/fallbacks")
async def admin_fallbacks(_auth=Depends(require_api_key)):
    """Admin endpoint exposing NVIDIA NIM fallback metrics.

    Protected by the same API key as other endpoints.

    Raises:
        HTTPException: 500 when the metrics snapshot cannot be read.
    """
    try:
        data = nvidia_nim_metrics.snapshot()
    except Exception as e:
        logger.warning("ADMIN_FALLBACKS: failed to read metrics: {}", e)
        # Chain the original exception so __cause__ is preserved for debugging
        # (the previous bare raise lost the causal link).
        raise HTTPException(status_code=500, detail="failed to read metrics") from e
    return {"provider": "nvidia_nim", "fallbacks": data}
|