File size: 10,443 Bytes
0157ac7
 
 
1985e64
0157ac7
1985e64
0157ac7
 
 
ef123a8
0157ac7
 
 
 
 
 
 
 
 
 
 
1985e64
 
0157ac7
 
 
 
 
 
 
04fcbd7
d6a1875
04fcbd7
 
 
ef22b95
d6a1875
0157ac7
 
 
ef22b95
0ba585f
 
 
 
5bba595
24b9325
 
 
 
5bba595
0223890
5bba595
c9c8b95
24b9325
 
98fdd46
 
 
 
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef123a8
 
 
 
 
 
 
 
 
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef123a8
 
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f56589d
0157ac7
 
 
 
 
f56589d
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1985e64
 
 
 
 
 
 
 
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
"""FastAPI route handlers."""

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import HTMLResponse
from loguru import logger
from starlette.templating import Jinja2Templates

from config.settings import Settings
from core.anthropic import get_token_count
from providers.nvidia_nim import metrics as nvidia_nim_metrics
from providers.registry import ProviderRegistry

from . import dependencies
from .dependencies import get_settings, require_api_key
from .gateway_model_ids import gateway_model_id, no_thinking_gateway_model_id
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import ModelResponse, ModelsListResponse
from .services import ClaudeProxyService

# Shared router collecting every endpoint defined in this module.
router = APIRouter()

# Jinja templates backing the HTML admin dashboard served from "/".
templates = Jinja2Templates(directory="templates")

# Discovered models carry no real creation time; advertise the Unix epoch.
DISCOVERED_MODEL_CREATED_AT = "1970-01-01T00:00:00Z"


# The proxy advertises a curated set of provider-backed models. Replace
# the previous hardcoded Claude model list with the requested NVIDIA-
# compatible models so clients only see those options.
# Entries are "<provider>/<model-id>" refs; a bare id without "/" is
# treated as an NVIDIA NIM model by _build_models_list_response.
REQUESTED_PROVIDER_MODELS = [
    # Zen/OpenCode free models
    "zen/minimax-m2.5-free",
    "zen/big-pickle",
    "zen/ring-2.6-1t-free",
    "zen/nemotron-3-super-free",
    # NVIDIA NIM models (top 5)
    "nvidia_nim/stepfun-ai/step-3.5-flash",
    "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
    "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
    "nvidia_nim/z-ai/glm4.7",
    "nvidia_nim/minimaxai/minimax-m2.7",
    # Cerebras models (key only has access to llama3.1-8b currently)
    # qwen-3-235b-a22b-instruct-2507 exists but is rate-limited
    # zai-glm-4.7 and gpt-oss-120b are not accessible with current key
    "cerebras/llama3.1-8b",
    # Silicon Flow models (top 5 for free tier)
    # DeepSeek-V3 - strong MoE model
    "silicon/deepseek-ai/DeepSeek-V3",
    # Qwen3-Coder-30B-A3B - coding specialized
    "silicon/Qwen/Qwen3-Coder-30B-A3B-Instruct",
    # Qwen3.6-35B-A3B - multimodal, 262K context
    "silicon/Qwen/Qwen3.6-35B-A3B",
    # Qwen2.5-72B - strong general purpose, 128K context
    "silicon/Qwen/Qwen2.5-72B-Instruct",
    # Qwen3-32B - reasoning model
    "silicon/Qwen/Qwen3-32B",
    # Groq models (ultra fast inference)
    "groq/llama-3.3-70b-versatile",
    "groq/llama-3.1-8b-instant",
    "groq/qwen3-32b",
]


def get_proxy_service(
    request: Request,
    settings: Settings = Depends(get_settings),
) -> ClaudeProxyService:
    """Construct the per-request proxy service used by the route handlers."""

    def _resolve(provider_type):
        # Providers are resolved lazily against the running app instance.
        return dependencies.resolve_provider(
            provider_type, app=request.app, settings=settings
        )

    return ClaudeProxyService(
        settings,
        provider_getter=_resolve,
        token_counter=get_token_count,
    )


def _probe_response(allow: str) -> Response:
    """Build the empty 204 reply used for HEAD/OPTIONS compatibility probes.

    The permitted methods are advertised via the ``Allow`` header.
    """
    headers = {"Allow": allow}
    return Response(status_code=204, headers=headers)


def _discovered_model_response(model_id: str, *, display_name: str) -> ModelResponse:
    """Wrap a discovered model id in a ModelResponse stamped with the fixed epoch time."""
    return ModelResponse(
        id=model_id,
        created_at=DISCOVERED_MODEL_CREATED_AT,
        display_name=display_name,
    )


def _append_unique_model(
    models: list[ModelResponse], seen: set[str], model: ModelResponse
) -> None:
    """Append ``model`` to ``models`` unless its id was already recorded in ``seen``."""
    if model.id not in seen:
        seen.add(model.id)
        models.append(model)


def _append_provider_model_variants(
    models: list[ModelResponse],
    seen: set[str],
    provider_model_ref: str,
    *,
    supports_thinking: bool | None = None,
) -> None:
    """Add the thinking and no-thinking gateway variants for one provider ref.

    The thinking variant is skipped only when ``supports_thinking`` is
    explicitly ``False``; ``None`` (unknown) still advertises it.
    """
    variants: list[tuple[str, str]] = []
    if supports_thinking is not False:
        variants.append((gateway_model_id(provider_model_ref), provider_model_ref))
    variants.append(
        (
            no_thinking_gateway_model_id(provider_model_ref),
            f"{provider_model_ref} (no thinking)",
        )
    )
    for variant_id, label in variants:
        _append_unique_model(
            models,
            seen,
            _discovered_model_response(variant_id, display_name=label),
        )


def _build_models_list_response(
    settings: Settings, provider_registry: ProviderRegistry | None
) -> ModelsListResponse:
    """Build the /v1/models payload from the curated provider model list.

    ``settings`` is accepted for interface compatibility with callers but is
    not currently consulted; the advertised set is the static
    REQUESTED_PROVIDER_MODELS list plus a virtual ``auto`` model.
    """
    models: list[ModelResponse] = []
    seen: set[str] = set()

    def _auto_model() -> ModelResponse:
        # Virtual `auto` model that maps to the configured MODEL and enables
        # automatic fallback behavior when used by clients.
        return ModelResponse(
            id=gateway_model_id("auto"),
            display_name="auto (use configured fallbacks)",
            created_at=DISCOVERED_MODEL_CREATED_AT,
        )

    # Advertise only the requested provider models (no Claude models, no
    # registry auto-discovery). Each ref is added with both thinking and
    # no-thinking variants.
    for provider_ref in REQUESTED_PROVIDER_MODELS:
        # If the ref already contains a provider prefix, use it as-is;
        # otherwise assume it belongs to the NVIDIA NIM provider.
        ref = provider_ref if "/" in provider_ref else f"nvidia_nim/{provider_ref}"
        supports_thinking = None
        if provider_registry is not None:
            # `ref` is guaranteed to contain "/" at this point, so split
            # unconditionally (the previous conditional fallback was dead code).
            provider, model_id = ref.split("/", 1)
            supports_thinking = provider_registry.cached_model_supports_thinking(
                provider, model_id
            )
        _append_provider_model_variants(
            models, seen, ref, supports_thinking=supports_thinking
        )

    _append_unique_model(models, seen, _auto_model())

    # Filter out any residual Claude-branded models so the proxy advertises
    # only the provider-backed models requested by the user.
    filtered = [
        m
        for m in models
        if "claude" not in (m.id or "").lower()
        and "claude" not in (m.display_name or "").lower()
    ]
    # Ensure `auto` model remains available even if filtering removed others.
    if not any(m.id == gateway_model_id("auto") for m in filtered):
        filtered.append(_auto_model())

    return ModelsListResponse(
        data=filtered,
        first_id=filtered[0].id if filtered else None,
        has_more=False,
        last_id=filtered[-1].id if filtered else None,
    )


# =============================================================================
# Routes
# =============================================================================
@router.post("/v1/messages")
async def create_message(
    request: Request,
    request_data: MessagesRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Create a message (always streaming).

    Delegates to the proxy service, forwarding both the raw ``request``
    and the parsed ``request_data`` body.
    """
    return service.create_message(request, request_data)


@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
async def probe_messages(_auth=Depends(require_api_key)):
    """Respond to Claude compatibility probes for the messages endpoint.

    Returns an empty 204 advertising the supported methods via ``Allow``.
    """
    return _probe_response("POST, HEAD, OPTIONS")


@router.post("/v1/messages/count_tokens")
async def count_tokens(
    request_data: TokenCountRequest,
    service: ClaudeProxyService = Depends(get_proxy_service),
    _auth=Depends(require_api_key),
):
    """Count tokens for a request.

    Delegates to the proxy service with the parsed token-count body.
    """
    return service.count_tokens(request_data)


@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
async def probe_count_tokens(_auth=Depends(require_api_key)):
    """Respond to Claude compatibility probes for the token count endpoint.

    Returns an empty 204 advertising the supported methods via ``Allow``.
    """
    return _probe_response("POST, HEAD, OPTIONS")


@router.get("/", response_class=HTMLResponse)
async def root(request: Request, _auth=Depends(require_api_key)):
    """Render the admin dashboard at the root path."""
    # Local import deferred to request time (presumably to avoid an import
    # cycle with the admin module — confirm before hoisting to module level).
    from .admin import _get_admin_data

    context = {"request": request}
    context.update(_get_admin_data())
    return templates.TemplateResponse("admin.html", context)


@router.api_route("/", methods=["HEAD", "OPTIONS"])
async def probe_root(_auth=Depends(require_api_key)):
    """Respond to compatibility probes for the root endpoint.

    Returns an empty 204 advertising the supported methods via ``Allow``.
    """
    return _probe_response("GET, HEAD, OPTIONS")


@router.get("/health")
async def health():
    """Health check endpoint.

    NOTE(review): unlike the other routes this has no ``require_api_key``
    dependency — presumably intentional so load balancers can probe it
    without credentials; confirm.
    """
    return {"status": "healthy"}


@router.api_route("/health", methods=["HEAD", "OPTIONS"])
async def probe_health():
    """Respond to compatibility probes for the health endpoint.

    Unauthenticated, matching the GET /health route.
    """
    return _probe_response("GET, HEAD, OPTIONS")


@router.get("/v1/models", response_model=ModelsListResponse)
async def list_models(
    request: Request,
    settings: Settings = Depends(get_settings),
    _auth=Depends(require_api_key),
):
    """List the model ids this proxy advertises to Claude-compatible clients."""
    registry = getattr(request.app.state, "provider_registry", None)
    # Only a genuine ProviderRegistry is forwarded; anything else on app
    # state is treated as "no registry available".
    if not isinstance(registry, ProviderRegistry):
        registry = None
    return _build_models_list_response(settings, registry)


@router.post("/stop")
async def stop_cli(request: Request, _auth=Depends(require_api_key)):
    """Stop all CLI sessions and pending tasks.

    Prefers the message handler; falls back to the CLI manager when the
    messaging layer is not initialized, and fails with 503 when neither
    is available.
    """
    state = request.app.state
    handler = getattr(state, "message_handler", None)
    if handler:
        cancelled = await handler.stop_all_tasks()
        logger.info("STOP_CLI: source=handler cancelled_count={}", cancelled)
        return {"status": "stopped", "cancelled_count": cancelled}

    # Fallback if messaging not initialized
    cli_manager = getattr(state, "cli_manager", None)
    if cli_manager:
        await cli_manager.stop_all()
        logger.info("STOP_CLI: source=cli_manager cancelled_count=N/A")
        return {"status": "stopped", "source": "cli_manager"}
    raise HTTPException(status_code=503, detail="Messaging system not initialized")


@router.get("/admin/fallbacks")
async def admin_fallbacks(_auth=Depends(require_api_key)):
    """Admin endpoint exposing NVIDIA NIM fallback metrics.

    Protected by the same API key as other endpoints.

    Raises:
        HTTPException: 500 when the metrics snapshot cannot be read.
    """
    try:
        data = nvidia_nim_metrics.snapshot()
    except Exception as e:
        logger.warning("ADMIN_FALLBACKS: failed to read metrics: {}", e)
        # Chain the original error (`from e`) so tracebacks show the root cause.
        raise HTTPException(status_code=500, detail="failed to read metrics") from e
    return {"provider": "nvidia_nim", "fallbacks": data}