# NOTE: removed non-source residue ("Spaces / Sleeping" Hugging Face page chrome).
"""
OpenAI-compatible /v1 API Gateway.

Proxies to the NVIDIA NIM API with streaming always enabled,
function-calling support, and per-model system prompts.
Deploy on Hugging Face Spaces (Docker).
Clients authenticate with: Authorization: Bearer <gateway key>
"""
import asyncio
import json
import os
import time
import uuid
from typing import Any, AsyncGenerator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from system_prompts import SYSTEM_PROMPTS, MODEL_MAP, REVERSE_MODEL_MAP, EXTRA_BODY_MODELS
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# SECURITY: an NVIDIA API key was previously hard-coded here and committed to
# source — that key is compromised and must be rotated.  All credentials now
# come from the environment; deployments set these as Space secrets.
NVIDIA_BASE_URL = os.environ.get("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1")
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY", "")
# Gateway-side bearer token clients must present; "connect" kept as the
# historical default so existing callers continue to work until overridden.
GATEWAY_API_KEY = os.environ.get("GATEWAY_API_KEY", "connect")
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
    title="AI Gateway",
    description="OpenAI-compatible gateway to NVIDIA NIM models",
    version="1.0.0",
)

# Wide-open CORS: the gateway is intended to be called from arbitrary
# browser origins, so every origin/method/header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Auth
# ---------------------------------------------------------------------------
def verify_api_key(request: Request) -> None:
    """Reject the request unless it carries ``Authorization: Bearer <key>``.

    Raises:
        HTTPException: 401 when the header is missing/malformed, or when the
            supplied token does not equal ``GATEWAY_API_KEY``.
    """
    header = request.headers.get("Authorization", "")
    if not header.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing Bearer token")
    supplied = header.removeprefix("Bearer ").strip()
    if supplied != GATEWAY_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")
# ---------------------------------------------------------------------------
# Pydantic models (OpenAI-compatible)
# ---------------------------------------------------------------------------
class FunctionParameters(BaseModel):
    """JSON-Schema ``parameters`` object of an OpenAI function definition."""
    type: str = "object"
    # Use default_factory rather than mutable class-level defaults ({} / []):
    # pydantic copies plain defaults, but the factory form is the explicit,
    # idiomatic way to guarantee per-instance containers.
    properties: dict[str, Any] = Field(default_factory=dict)
    required: list[str] = Field(default_factory=list)
class FunctionDef(BaseModel):
    """An OpenAI ``function`` tool: name, optional description and schema."""
    name: str
    description: str | None = None
    parameters: FunctionParameters | None = None
class Tool(BaseModel):
    """Tool wrapper; only ``function`` tools are modelled."""
    type: str = "function"
    function: FunctionDef
class ToolChoice(BaseModel):
    """Structured ``tool_choice`` forcing a specific function call."""
    type: str = "function"
    # When present, expected shape is {"name": "<function name>"}.
    function: dict[str, str] | None = None
class Message(BaseModel):
    """One chat message, mirroring the OpenAI chat message schema."""
    role: str
    # content is either a plain string or a list of content parts.
    content: str | list[Any] | None = None
    name: str | None = None
    tool_calls: list[Any] | None = None
    tool_call_id: str | None = None
class ChatCompletionRequest(BaseModel):
    """Incoming OpenAI-compatible chat-completion request body."""
    model: str
    messages: list[Message]
    temperature: float | None = None
    top_p: float | None = None
    max_tokens: int | None = None
    tools: list[Tool] | None = None
    tool_choice: str | ToolChoice | None = None
    # The gateway ALWAYS streams: a client-supplied value is accepted for
    # compatibility but ignored — upstream requests are forced to stream.
    stream: bool = True
    stop: list[str] | str | None = None
    presence_penalty: float | None = None
    frequency_penalty: float | None = None
    seed: int | None = None
    n: int | None = None
    logprobs: bool | None = None
    top_logprobs: int | None = None
    user: str | None = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def resolve_model(requested: str) -> str:
    """Map a display name or raw NVIDIA model ID to the NVIDIA model ID.

    Raises:
        HTTPException: 400 when the name is neither a known display name
            nor a known raw NVIDIA ID.
    """
    mapped = MODEL_MAP.get(requested)
    if mapped is not None:
        return mapped
    if requested in REVERSE_MODEL_MAP:
        # Caller already supplied the raw NVIDIA ID.
        return requested
    raise HTTPException(
        status_code=400,
        detail=f"Unknown model '{requested}'. Available: {list(MODEL_MAP.keys())}",
    )
def get_display_name(nvidia_id: str) -> str:
    """Translate a raw NVIDIA model ID to its display name (echo if unknown)."""
    return REVERSE_MODEL_MAP.get(nvidia_id, nvidia_id)
def inject_system_prompt(messages: list[Message], display_name: str) -> list[dict]:
    """Serialize messages, prepending the model's system prompt when the
    conversation does not already contain a system message.

    Returns plain dicts (None-valued fields dropped) ready for the upstream
    JSON body.
    """
    as_dicts = [msg.model_dump(exclude_none=True) for msg in messages]
    prompt = SYSTEM_PROMPTS.get(display_name)
    if prompt and all(msg["role"] != "system" for msg in as_dicts):
        as_dicts.insert(0, {"role": "system", "content": prompt})
    return as_dicts
def build_nvidia_payload(req: ChatCompletionRequest, nvidia_model: str) -> dict:
    """Build the upstream NVIDIA NIM request body.

    Streaming is always forced on.  Optional OpenAI parameters are forwarded
    only when the client supplied them.  Fix over the previous version:
    ``logprobs`` and ``top_logprobs`` were accepted by the request model but
    silently dropped — they are now forwarded as well.
    """
    display = get_display_name(nvidia_model)
    payload: dict[str, Any] = {
        "model": nvidia_model,
        "messages": inject_system_prompt(req.messages, display),
        "stream": True,  # the gateway only speaks SSE
    }
    # Optional passthrough parameters: copied verbatim when not None.
    passthrough = (
        "temperature", "top_p", "max_tokens", "stop", "presence_penalty",
        "frequency_penalty", "seed", "n", "user", "logprobs", "top_logprobs",
    )
    for field_name in passthrough:
        value = getattr(req, field_name)
        if value is not None:
            payload[field_name] = value
    # Function calling / tools.
    if req.tools:
        payload["tools"] = [tool.model_dump(exclude_none=True) for tool in req.tools]
    if req.tool_choice is not None:
        if isinstance(req.tool_choice, str):
            payload["tool_choice"] = req.tool_choice
        else:
            payload["tool_choice"] = req.tool_choice.model_dump(exclude_none=True)
    # Per-model extra body (e.g. GLM-4.x thinking params).
    payload.update(EXTRA_BODY_MODELS.get(nvidia_model, {}))
    return payload
# ---------------------------------------------------------------------------
# SSE streaming proxy
# ---------------------------------------------------------------------------
async def stream_nvidia(payload: dict) -> AsyncGenerator[bytes, None]:
    """Proxy the upstream SSE stream, re-emitting each ``data:`` event.

    Fixes over the previous version: transport-level failures (DNS, timeout,
    connection reset) raised out of the generator and dropped the client
    connection mid-stream; and a stream that ended without ``[DONE]`` was
    left unterminated.  Both cases now emit an OpenAI-style ``error`` event
    (where applicable) followed by ``data: [DONE]`` so clients always see a
    well-terminated stream.
    """
    headers = {
        "Authorization": f"Bearer {NVIDIA_API_KEY}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    def _error_event(message: str, code: Any) -> bytes:
        chunk = {"error": {"message": message, "type": "upstream_error", "code": code}}
        return f"data: {json.dumps(chunk)}\n\n".encode()

    try:
        async with httpx.AsyncClient(timeout=300) as client:
            async with client.stream(
                "POST",
                f"{NVIDIA_BASE_URL}/chat/completions",
                headers=headers,
                json=payload,
            ) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    detail = body.decode(errors="replace")
                    yield _error_event(
                        f"Upstream error {response.status_code}: {detail}",
                        response.status_code,
                    )
                    yield b"data: [DONE]\n\n"
                    return
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        yield f"{line}\n\n".encode()
                        if line == "data: [DONE]":
                            return
                    elif line.strip():
                        # Pass through any unexpected non-blank lines as data.
                        yield f"data: {line}\n\n".encode()
        # Upstream closed without [DONE]; terminate the stream cleanly anyway.
        yield b"data: [DONE]\n\n"
    except httpx.HTTPError as exc:
        # Transport failure before or during streaming.
        yield _error_event(f"Upstream connection error: {exc}", 502)
        yield b"data: [DONE]\n\n"
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
# Fix: handlers below were plain coroutines with no route decorators, so
# nothing was ever registered on the app and every endpoint 404'd.
@app.get("/")
async def root():
    """Unauthenticated liveness/info endpoint."""
    return {"status": "ok", "service": "AI Gateway", "version": "1.0.0"}
@app.get("/v1/models")
async def list_models(request: Request):
    """OpenAI-compatible model listing (requires the gateway Bearer token).

    Fix: the handler previously had no route decorator and was unreachable.
    """
    verify_api_key(request)
    created = int(time.time())
    data = [
        {
            "id": display_name,
            "object": "model",
            "created": created,
            "owned_by": "ai-gateway",
        }
        for display_name in MODEL_MAP
    ]
    return {"object": "list", "data": data}
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, req: ChatCompletionRequest):
    """OpenAI-compatible chat completions; responses always stream as SSE.

    Fix: the handler previously had no route decorator and was unreachable.
    """
    verify_api_key(request)
    nvidia_model = resolve_model(req.model)
    payload = build_nvidia_payload(req, nvidia_model)
    return StreamingResponse(
        stream_nvidia(payload),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable reverse-proxy buffering
        },
    )
# Passthrough completions (legacy)
@app.post("/v1/completions")
async def completions(request: Request):
    """Legacy /v1/completions passthrough; the JSON body is forwarded
    near-verbatim with the model name resolved and streaming forced on.

    Fix: the handler previously had no route decorator and was unreachable.
    """
    verify_api_key(request)
    body = await request.json()
    requested = body.get("model", "")
    try:
        body["model"] = resolve_model(requested)
    except HTTPException:
        # Unknown alias: forward the raw value and let upstream reject it.
        body["model"] = requested
    body["stream"] = True  # the gateway only speaks SSE
    return StreamingResponse(
        stream_nvidia(body),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable reverse-proxy buffering
        },
    )
@app.get("/health")
async def health():
    """Unauthenticated health probe for the platform's checker.

    Fix: the handler previously had no route decorator and was unreachable.
    """
    return {"status": "healthy"}