""" OpenAI-compatible /v1 API Gateway Proxies to NVIDIA NIM API with streaming always enabled, function calling support, and per-model system prompts. Deploy on Hugging Face Spaces (Docker). Authorization: Bearer connect """ import json import time import uuid import asyncio from typing import Any, AsyncGenerator import httpx from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel, Field from system_prompts import SYSTEM_PROMPTS, MODEL_MAP, REVERSE_MODEL_MAP, EXTRA_BODY_MODELS # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1" NVIDIA_API_KEY = "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" GATEWAY_API_KEY = "connect" # --------------------------------------------------------------------------- # App # --------------------------------------------------------------------------- app = FastAPI( title="AI Gateway", description="OpenAI-compatible gateway to NVIDIA NIM models", version="1.0.0", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # --------------------------------------------------------------------------- # Auth # --------------------------------------------------------------------------- def verify_api_key(request: Request) -> None: auth = request.headers.get("Authorization", "") if not auth.startswith("Bearer "): raise HTTPException(status_code=401, detail="Missing Bearer token") token = auth.removeprefix("Bearer ").strip() if token != GATEWAY_API_KEY: raise HTTPException(status_code=401, detail="Invalid API key") # --------------------------------------------------------------------------- # Pydantic models (OpenAI-compatible) # --------------------------------------------------------------------------- class FunctionParameters(BaseModel): type: str = "object" properties: dict[str, Any] = {} required: list[str] = [] class FunctionDef(BaseModel): name: str description: str | None = None parameters: FunctionParameters | None = None class Tool(BaseModel): type: str = "function" function: FunctionDef class ToolChoice(BaseModel): type: str = "function" function: dict[str, str] | None = None class Message(BaseModel): role: str content: str | list[Any] | None = None name: str | None = None tool_calls: list[Any] | None = None tool_call_id: str | None = None class ChatCompletionRequest(BaseModel): model: str messages: list[Message] temperature: float | None = None top_p: float | None = None max_tokens: int | None = None tools: list[Tool] | None = None tool_choice: str | ToolChoice | None = None # stream is ALWAYS True – ignored if provided, always forced to True stream: bool = True stop: list[str] | str | None = None presence_penalty: float | None = None frequency_penalty: float | None = None seed: int | None = None n: int | None = None logprobs: bool | None = None top_logprobs: int | None = None user: str | None = None # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def resolve_model(requested: str) -> str: """Map display name or raw NVIDIA model ID to NVIDIA model ID.""" if requested in MODEL_MAP: return MODEL_MAP[requested] if requested in REVERSE_MODEL_MAP: 
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def resolve_model(requested: str) -> str:
    """Map display name or raw NVIDIA model ID to NVIDIA model ID."""
    if requested in MODEL_MAP:
        return MODEL_MAP[requested]
    if requested in REVERSE_MODEL_MAP:
        return requested  # already a raw ID
    raise HTTPException(
        status_code=400,
        detail=f"Unknown model '{requested}'. Available: {list(MODEL_MAP.keys())}",
    )


def get_display_name(nvidia_id: str) -> str:
    return REVERSE_MODEL_MAP.get(nvidia_id, nvidia_id)


def inject_system_prompt(messages: list[Message], display_name: str) -> list[dict]:
    """Inject per-model system prompt if not already present."""
    prompt = SYSTEM_PROMPTS.get(display_name)
    serialized = [m.model_dump(exclude_none=True) for m in messages]
    if prompt:
        has_system = any(m["role"] == "system" for m in serialized)
        if not has_system:
            serialized = [{"role": "system", "content": prompt}] + serialized
    return serialized


def build_nvidia_payload(req: ChatCompletionRequest, nvidia_model: str) -> dict:
    display = get_display_name(nvidia_model)
    messages = inject_system_prompt(req.messages, display)

    payload: dict[str, Any] = {
        "model": nvidia_model,
        "messages": messages,
        "stream": True,  # ALWAYS TRUE
    }

    # Optional params
    if req.temperature is not None:
        payload["temperature"] = req.temperature
    if req.top_p is not None:
        payload["top_p"] = req.top_p
    if req.max_tokens is not None:
        payload["max_tokens"] = req.max_tokens
    if req.stop is not None:
        payload["stop"] = req.stop
    if req.presence_penalty is not None:
        payload["presence_penalty"] = req.presence_penalty
    if req.frequency_penalty is not None:
        payload["frequency_penalty"] = req.frequency_penalty
    if req.seed is not None:
        payload["seed"] = req.seed
    if req.n is not None:
        payload["n"] = req.n
    if req.user is not None:
        payload["user"] = req.user

    # Function calling / tools
    if req.tools:
        payload["tools"] = [t.model_dump(exclude_none=True) for t in req.tools]
    if req.tool_choice is not None:
        if isinstance(req.tool_choice, str):
            payload["tool_choice"] = req.tool_choice
        else:
            payload["tool_choice"] = req.tool_choice.model_dump(exclude_none=True)

    # Extra body for specific models (e.g. GLM-4.7 thinking params)
    extra = EXTRA_BODY_MODELS.get(nvidia_model, {})
    payload.update(extra)

    return payload
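# The helpers above assume the structures imported from system_prompts.py have
# roughly these shapes (inferred from how they are used here; all values below
# are hypothetical placeholders, not the real configuration):
#
#   MODEL_MAP = {"my-display-name": "vendor/raw-nvidia-model-id"}
#   REVERSE_MODEL_MAP = {v: k for k, v in MODEL_MAP.items()}
#   SYSTEM_PROMPTS = {"my-display-name": "You are a helpful assistant."}
#   EXTRA_BODY_MODELS = {"vendor/raw-nvidia-model-id": {"some_extra_param": True}}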
# ---------------------------------------------------------------------------
# SSE streaming proxy
# ---------------------------------------------------------------------------

async def stream_nvidia(payload: dict, endpoint: str = "chat/completions") -> AsyncGenerator[bytes, None]:
    headers = {
        "Authorization": f"Bearer {NVIDIA_API_KEY}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    async with httpx.AsyncClient(timeout=300) as client:
        async with client.stream(
            "POST",
            f"{NVIDIA_BASE_URL}/{endpoint}",
            headers=headers,
            json=payload,
        ) as response:
            if response.status_code != 200:
                body = await response.aread()
                error_detail = body.decode(errors="replace")
                error_chunk = {
                    "error": {
                        "message": f"Upstream error {response.status_code}: {error_detail}",
                        "type": "upstream_error",
                        "code": response.status_code,
                    }
                }
                yield f"data: {json.dumps(error_chunk)}\n\n".encode()
                yield b"data: [DONE]\n\n"
                return

            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    yield f"{line}\n\n".encode()
                    if line == "data: [DONE]":
                        return
                elif line.strip():
                    # Pass through any unexpected non-empty lines as data events
                    yield f"data: {line}\n\n".encode()

# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------

@app.get("/")
async def root():
    return {"status": "ok", "service": "AI Gateway", "version": "1.0.0"}


@app.get("/v1/models")
async def list_models(request: Request):
    verify_api_key(request)
    now = int(time.time())
    models = []
    for display_name in MODEL_MAP:
        models.append({
            "id": display_name,
            "object": "model",
            "created": now,
            "owned_by": "ai-gateway",
        })
    return {"object": "list", "data": models}


@app.post("/v1/chat/completions")
async def chat_completions(request: Request, req: ChatCompletionRequest):
    verify_api_key(request)
    nvidia_model = resolve_model(req.model)
    payload = build_nvidia_payload(req, nvidia_model)
    return StreamingResponse(
        stream_nvidia(payload),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


# Passthrough completions (legacy)
@app.post("/v1/completions")
async def completions(request: Request):
    verify_api_key(request)
    body = await request.json()
    model_req = body.get("model", "")
    try:
        nvidia_model = resolve_model(model_req)
    except HTTPException:
        nvidia_model = model_req
    body["model"] = nvidia_model
    body["stream"] = True
    return StreamingResponse(
        # Legacy payloads use "prompt", so proxy to the upstream /completions
        # endpoint rather than /chat/completions.
        stream_nvidia(body, endpoint="completions"),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@app.get("/health")
async def health():
    return {"status": "healthy"}
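# Local entry point for testing outside the Space (a sketch that assumes the
# `uvicorn` package is installed; 7860 matches the default Hugging Face Spaces
# port). A hypothetical client call against a running instance could look like:
#
#   curl -N http://localhost:7860/v1/chat/completions \
#     -H "Authorization: Bearer connect" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "<display name from MODEL_MAP>", "messages": [{"role": "user", "content": "Hi"}]}'
#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)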