oki692 committed on
Commit
1a6f0f5
·
verified ·
1 Parent(s): f46f8fd

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +222 -201
main.py CHANGED
@@ -1,271 +1,292 @@
1
  """
2
- Multi-model AI gateway endpoint — HF Spaces compatible.
3
- Authorization via 'connect' API key header.
4
- Streaming always enabled. Function calling supported.
5
- Uses httpx async — no openai SDK network issues.
 
 
6
  """
7
 
8
  import json
9
- from typing import AsyncGenerator, Optional
 
 
 
10
 
11
  import httpx
12
- from fastapi import FastAPI, HTTPException, Header, Request
13
- from fastapi.responses import StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
 
15
  from pydantic import BaseModel, Field
16
 
17
- from system_prompts import get_system_prompt
18
-
19
- # ── Config ──────────────────────────────────────────────────────────────────
20
 
21
- CONNECT_KEY = "connect"
 
 
22
 
23
  NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
24
  NVIDIA_API_KEY = "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"
25
- NVIDIA_CHAT_URL = f"{NVIDIA_BASE_URL}/chat/completions"
26
-
27
- # Model registry: display-name → real model id + optional extra body
28
- MODELS = {
29
- "Bielik-11b": {
30
- "model_id": "speakleash/bielik-11b-v2.6-instruct",
31
- "extra_body": {
32
- "chat_template_kwargs": {
33
- "enable_thinking": False,
34
- "clear_thinking": True,
35
- }
36
- },
37
- },
38
- "GLM-4.7": {
39
- "model_id": "z-ai/glm4.7",
40
- "extra_body": {
41
- "chat_template_kwargs": {
42
- "enable_thinking": False,
43
- "clear_thinking": True,
44
- }
45
- },
46
- },
47
- "Mistral-Small-4": {
48
- "model_id": "mistralai/mistral-small-4-119b-2603",
49
- "extra_body": {},
50
- },
51
- "DeepSeek-V3.1": {
52
- "model_id": "deepseek-ai/deepseek-v3.1",
53
- "extra_body": {},
54
- },
55
- "Kimi-K2": {
56
- "model_id": "moonshotai/kimi-k2-instruct",
57
- "extra_body": {},
58
- },
59
- }
60
-
61
- # ── FastAPI ──────────────────────────────────────────────────────────────────
62
 
63
  app = FastAPI(
64
- title="Multi-Model AI Gateway",
 
65
  version="1.0.0",
66
- description="Streaming endpoint for Bielik-11b, GLM-4.7, Mistral-Small-4, DeepSeek-V3.1, Kimi-K2",
67
  )
68
 
69
  app.add_middleware(
70
  CORSMiddleware,
71
  allow_origins=["*"],
 
72
  allow_methods=["*"],
73
  allow_headers=["*"],
74
  )
75
 
76
- # ── Auth ─────────────────────────────────────────────────────────────────────
 
 
77
 
78
- def verify_key(authorization: Optional[str]) -> None:
79
- if not authorization:
80
- raise HTTPException(status_code=401, detail="Missing Authorization header")
81
- scheme, _, token = authorization.partition(" ")
82
- if scheme.lower() != "bearer" or token != CONNECT_KEY:
83
- raise HTTPException(status_code=403, detail="Invalid API key")
 
84
 
85
- # ── Schemas ───────────────────────────────────────────────────────────────────
 
 
86
 
87
- class Message(BaseModel):
88
- role: str
89
- content: str | list
 
90
 
91
- class ToolFunction(BaseModel):
92
  name: str
93
- description: Optional[str] = None
94
- parameters: Optional[dict] = None
95
 
96
  class Tool(BaseModel):
97
  type: str = "function"
98
- function: ToolFunction
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- class ChatRequest(BaseModel):
101
- model: str = Field(..., description="Bielik-11b | GLM-4.7 | Mistral-Small-4 | DeepSeek-V3.1 | Kimi-K2")
102
  messages: list[Message]
103
- tools: Optional[list[Tool]] = None
104
- tool_choice: Optional[str | dict] = None
105
- temperature: Optional[float] = None
106
- max_tokens: Optional[int] = None
107
- top_p: Optional[float] = None
108
- presence_penalty: Optional[float] = None
109
- frequency_penalty: Optional[float] = None
110
- inject_system_prompt: bool = Field(default=True)
111
-
112
- # ── Core stream helper ────────────────────────────────────────────────────────
113
-
114
- def _build_payload(model_name: str, messages: list[dict], tools, tool_choice, kwargs: dict) -> dict:
115
- cfg = MODELS[model_name]
116
- payload: dict = {
117
- "model": cfg["model_id"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  "messages": messages,
119
- "stream": True,
120
- **kwargs,
121
  }
122
- if tools:
123
- payload["tools"] = tools
124
- if tool_choice is not None:
125
- payload["tool_choice"] = tool_choice
126
- # merge extra_body fields at top level (NVIDIA NIM style)
127
- if cfg["extra_body"]:
128
- payload.update(cfg["extra_body"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  return payload
130
 
 
 
 
131
 
132
- async def stream_nvidia(
133
- model_name: str,
134
- messages: list[dict],
135
- tools,
136
- tool_choice,
137
- kwargs: dict,
138
- ) -> AsyncGenerator[bytes, None]:
139
- payload = _build_payload(model_name, messages, tools, tool_choice, kwargs)
140
  headers = {
141
  "Authorization": f"Bearer {NVIDIA_API_KEY}",
142
- "Content-Type": "application/json",
143
- "Accept": "text/event-stream",
144
  }
145
 
146
- async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=10.0)) as client:
147
  async with client.stream(
148
  "POST",
149
- NVIDIA_CHAT_URL,
150
  headers=headers,
151
  json=payload,
152
  ) as response:
153
  if response.status_code != 200:
154
  body = await response.aread()
155
- error_msg = body.decode(errors="replace")
156
- yield f"data: {json.dumps({'error': error_msg, 'status': response.status_code})}\n\n".encode()
 
 
 
 
 
 
 
 
157
  return
158
 
159
  async for line in response.aiter_lines():
160
- if line:
161
  yield f"{line}\n\n".encode()
 
 
 
 
 
162
 
163
- # ── Shared logic ──────────────────────────────────────────────────────────────
164
-
165
- def prepare_messages(model_name: str, raw_messages: list[dict], inject: bool) -> list[dict]:
166
- messages = list(raw_messages)
167
- if inject:
168
- system_prompt = get_system_prompt(model_name)
169
- if not messages or messages[0].get("role") != "system":
170
- messages.insert(0, {"role": "system", "content": system_prompt})
171
- return messages
172
-
173
-
174
- def extract_kwargs(source, fields: tuple) -> dict:
175
- kwargs = {}
176
- for field in fields:
177
- if isinstance(source, dict):
178
- val = source.get(field)
179
- else:
180
- val = getattr(source, field, None)
181
- if val is not None:
182
- kwargs[field] = val
183
- return kwargs
184
-
185
-
186
- OPTIONAL_FIELDS = ("temperature", "max_tokens", "top_p", "presence_penalty", "frequency_penalty")
187
-
188
- SSE_HEADERS = {
189
- "Cache-Control": "no-cache",
190
- "X-Accel-Buffering": "no",
191
- }
192
-
193
- # ── Endpoints ─────────────────────────────────────────────────────────────────
194
 
195
  @app.get("/")
196
  async def root():
197
- return {
198
- "service": "Multi-Model AI Gateway",
199
- "models": list(MODELS.keys()),
200
- "auth": "Bearer connect",
201
- "docs": "/docs",
202
- }
203
-
204
-
205
- @app.get("/models")
206
- async def list_models(authorization: Optional[str] = Header(default=None)):
207
- verify_key(authorization)
208
- return {
209
- name: {"model_id": cfg["model_id"]}
210
- for name, cfg in MODELS.items()
211
- }
212
 
 
 
 
213
 
214
- @app.post("/chat")
215
- async def chat(
216
- request: ChatRequest,
217
- authorization: Optional[str] = Header(default=None),
218
- ):
219
- verify_key(authorization)
220
-
221
- if request.model not in MODELS:
222
- raise HTTPException(
223
- status_code=400,
224
- detail=f"Unknown model '{request.model}'. Available: {list(MODELS.keys())}",
225
- )
226
-
227
- messages = prepare_messages(
228
- request.model,
229
- [m.model_dump() for m in request.messages],
230
- request.inject_system_prompt,
231
- )
232
- kwargs = extract_kwargs(request, OPTIONAL_FIELDS)
233
- tools = [t.model_dump() for t in request.tools] if request.tools else None
234
 
235
  return StreamingResponse(
236
- stream_nvidia(request.model, messages, tools, request.tool_choice, kwargs),
237
  media_type="text/event-stream",
238
- headers=SSE_HEADERS,
 
 
 
 
239
  )
240
 
 
 
 
 
 
241
 
242
- @app.post("/v1/chat/completions")
243
- async def openai_compat(
244
- raw: Request,
245
- authorization: Optional[str] = Header(default=None),
246
- ):
247
- """OpenAI-compatible drop-in. Use gateway model names as 'model'."""
248
- verify_key(authorization)
249
- body = await raw.json()
250
-
251
- model_name = body.get("model", "")
252
- if model_name not in MODELS:
253
- raise HTTPException(
254
- status_code=400,
255
- detail=f"Unknown model '{model_name}'. Available: {list(MODELS.keys())}",
256
- )
257
-
258
- messages = prepare_messages(
259
- model_name,
260
- body.get("messages", []),
261
- body.get("inject_system_prompt", True),
262
- )
263
- kwargs = extract_kwargs(body, OPTIONAL_FIELDS)
264
- tools = body.get("tools")
265
- tool_choice = body.get("tool_choice")
266
 
267
  return StreamingResponse(
268
- stream_nvidia(model_name, messages, tools, tool_choice, kwargs),
269
  media_type="text/event-stream",
270
- headers=SSE_HEADERS,
271
- )
 
 
 
 
 
 
 
 
 
1
  """
2
+ OpenAI-compatible /v1 API Gateway
3
+ Proxies to NVIDIA NIM API with streaming always enabled,
4
+ function calling support, and per-model system prompts.
5
+
6
+ Deploy on Hugging Face Spaces (Docker).
7
+ Authorization: Bearer connect
8
  """
9
 
10
import asyncio
import hmac
import json
import os
import time
import uuid
from typing import Any, AsyncGenerator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field

from system_prompts import SYSTEM_PROMPTS, MODEL_MAP, REVERSE_MODEL_MAP, EXTRA_BODY_MODELS
 
 
23
 
24
+ # ---------------------------------------------------------------------------
25
+ # Config
26
+ # ---------------------------------------------------------------------------
27
 
28
NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"

# Secrets come from the environment (for example, secrets configured on the
# hosting platform) instead of being committed to the repository. The key that
# used to be hard-coded on this line is exposed in the commit history and
# should be revoked and rotated.
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY", "")

# Token that clients present as "Authorization: Bearer <token>". The default
# keeps the previously documented value so existing clients continue to work;
# set GATEWAY_API_KEY in the environment to override it.
GATEWAY_API_KEY = os.environ.get("GATEWAY_API_KEY", "connect")
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # App
34
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
# Application instance that the middleware and routes in this module attach to.
app = FastAPI(title="AI Gateway", version="1.0.0", description="OpenAI-compatible gateway to NVIDIA NIM models")
41
 
42
# CORS policy. Clients authenticate with an explicit "Authorization: Bearer"
# header rather than cookies, so credentialed CORS is not required. Wildcard
# origins/methods/headers must not be combined with allow_credentials=True
# (the FastAPI CORS documentation calls this out explicitly), so credentials
# are disabled while the wildcard origin is kept.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
49
 
50
+ # ---------------------------------------------------------------------------
51
+ # Auth
52
+ # ---------------------------------------------------------------------------
53
 
54
def verify_api_key(request: Request) -> None:
    """Reject the request unless it carries the gateway's Bearer token.

    Args:
        request: Incoming request whose ``Authorization`` header is inspected.

    Raises:
        HTTPException: 401 with detail ``"Missing Bearer token"`` when the
            header is absent or does not use the Bearer scheme, and 401 with
            detail ``"Invalid API key"`` when the token does not match
            ``GATEWAY_API_KEY``.
    """
    challenge = {"WWW-Authenticate": "Bearer"}
    scheme, separator, token = request.headers.get("Authorization", "").partition(" ")

    # HTTP authentication scheme names are case-insensitive, so "bearer" and
    # "BEARER" are accepted alongside "Bearer".
    if not separator or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Missing Bearer token", headers=challenge)

    # Constant-time comparison. Both sides are encoded to bytes because
    # hmac.compare_digest raises TypeError for non-ASCII str operands.
    supplied = token.strip().encode("utf-8")
    expected = GATEWAY_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid API key", headers=challenge)
61
 
62
+ # ---------------------------------------------------------------------------
63
+ # Pydantic models (OpenAI-compatible)
64
+ # ---------------------------------------------------------------------------
65
 
66
class FunctionParameters(BaseModel):
    # JSON-Schema object describing the arguments of a tool function.
    #
    # Only the three most common keywords are declared as fields. Any other
    # keyword a client sends (additionalProperties, $defs, description, ...)
    # is retained via extra="allow" so that model_dump() forwards the schema
    # upstream intact instead of silently dropping those keys.
    model_config = {"extra": "allow"}

    type: str = "object"
    properties: dict[str, Any] = {}
    required: list[str] = []
70
 
71
class FunctionDef(BaseModel):
    # Definition of one callable function offered to the model as a tool.
    # Only the name is mandatory; description and parameter schema may be omitted.
    name: str
    description: str | None = Field(default=None)
    parameters: FunctionParameters | None = Field(default=None)
75
 
76
class Tool(BaseModel):
    # One entry of the request's "tools" array: a type tag (defaulting to
    # "function") wrapping the function definition itself.
    type: str = Field(default="function")
    function: FunctionDef
79
+
80
class ToolChoice(BaseModel):
    # Object form of "tool_choice"; the plain-string form is handled by the
    # request model's union type rather than by this class.
    type: str = Field(default="function")
    function: dict[str, str] | None = Field(default=None)
83
+
84
class Message(BaseModel):
    # A single chat message. Everything except the role is optional so the
    # same model can carry user, assistant, system and tool messages.
    role: str
    content: str | list[Any] | None = Field(default=None)
    name: str | None = Field(default=None)
    tool_calls: list[Any] | None = Field(default=None)
    tool_call_id: str | None = Field(default=None)
90
 
91
class ChatCompletionRequest(BaseModel):
    # Request body accepted by the chat-completions route.
    model: str
    messages: list[Message]

    # Sampling / length controls; None means "not supplied by the client".
    temperature: float | None = Field(default=None)
    top_p: float | None = Field(default=None)
    max_tokens: int | None = Field(default=None)

    # Function calling.
    tools: list[Tool] | None = Field(default=None)
    tool_choice: str | ToolChoice | None = Field(default=None)

    # Accepted for client compatibility only: the payload builder in this
    # module always sends stream=True upstream, whatever value arrives here.
    stream: bool = Field(default=True)

    stop: list[str] | str | None = Field(default=None)
    presence_penalty: float | None = Field(default=None)
    frequency_penalty: float | None = Field(default=None)
    seed: int | None = Field(default=None)
    n: int | None = Field(default=None)
    logprobs: bool | None = Field(default=None)
    top_logprobs: int | None = Field(default=None)
    user: str | None = Field(default=None)
109
+
110
+ # ---------------------------------------------------------------------------
111
+ # Helpers
112
+ # ---------------------------------------------------------------------------
113
+
114
def resolve_model(requested: str) -> str:
    """Translate a client-supplied model name into an NVIDIA model ID.

    Display names found in ``MODEL_MAP`` are mapped to their NVIDIA ID; names
    that are already keys of ``REVERSE_MODEL_MAP`` are returned unchanged.

    Raises:
        HTTPException: 400 when the name is found in neither mapping.
    """
    is_display_name = requested in MODEL_MAP
    is_raw_id = requested in REVERSE_MODEL_MAP

    # Guard clause: reject anything that neither mapping recognises.
    if not (is_display_name or is_raw_id):
        raise HTTPException(
            status_code=400,
            detail=f"Unknown model '{requested}'. Available: {list(MODEL_MAP.keys())}",
        )

    # Display names take precedence over raw IDs, as in a two-step lookup.
    return MODEL_MAP[requested] if is_display_name else requested
124
+
125
def get_display_name(nvidia_id: str) -> str:
    """Return the display name registered for *nvidia_id*, or the ID itself if none is."""
    if nvidia_id in REVERSE_MODEL_MAP:
        return REVERSE_MODEL_MAP[nvidia_id]
    return nvidia_id
127
+
128
def inject_system_prompt(messages: list[Message], display_name: str) -> list[dict]:
    """Serialize *messages* and prepend the model's system prompt when needed.

    The prompt registered for *display_name* in ``SYSTEM_PROMPTS`` is added as
    the first message only if one exists and the conversation does not already
    contain a message with role ``"system"``.

    Returns:
        The messages as plain dicts (``None``-valued fields omitted).
    """
    prompt = SYSTEM_PROMPTS.get(display_name)
    dumped = [msg.model_dump(exclude_none=True) for msg in messages]

    # Nothing to inject: no prompt configured, or the caller supplied its own.
    if not prompt:
        return dumped
    for entry in dumped:
        if entry["role"] == "system":
            return dumped

    return [{"role": "system", "content": prompt}, *dumped]
139
+
140
def build_nvidia_payload(req: ChatCompletionRequest, nvidia_model: str) -> dict:
    """Build the JSON body sent upstream for a chat-completions request.

    Args:
        req: Validated client request.
        nvidia_model: Upstream model ID, as returned by ``resolve_model``.

    Returns:
        A dict containing the model ID, the messages (with the per-model
        system prompt injected when applicable), ``stream=True``, every
        optional parameter the client actually supplied, the tool definitions,
        and any model-specific entries from ``EXTRA_BODY_MODELS``.
    """
    display = get_display_name(nvidia_model)

    payload: dict[str, Any] = {
        "model": nvidia_model,
        "messages": inject_system_prompt(req.messages, display),
        # Streaming is forced on regardless of the client's "stream" value.
        "stream": True,
    }

    # Scalar options are forwarded only when the client set them.
    # logprobs/top_logprobs are part of the request model and were previously
    # accepted but never sent upstream; they are forwarded like the rest.
    optional_fields = (
        "temperature",
        "top_p",
        "max_tokens",
        "stop",
        "presence_penalty",
        "frequency_penalty",
        "seed",
        "n",
        "logprobs",
        "top_logprobs",
        "user",
    )
    for field_name in optional_fields:
        value = getattr(req, field_name)
        if value is not None:
            payload[field_name] = value

    # Function calling / tools.
    if req.tools:
        payload["tools"] = [tool.model_dump(exclude_none=True) for tool in req.tools]
    if req.tool_choice is not None:
        choice = req.tool_choice
        payload["tool_choice"] = choice if isinstance(choice, str) else choice.model_dump(exclude_none=True)

    # Model-specific extras are merged last, so they win over client values.
    payload.update(EXTRA_BODY_MODELS.get(nvidia_model, {}))

    return payload
184
 
185
+ # ---------------------------------------------------------------------------
186
+ # SSE streaming proxy
187
+ # ---------------------------------------------------------------------------
188
 
189
def _sse_error_event(message: str, code: int) -> bytes:
    """Encode an error object as a single SSE ``data:`` event."""
    chunk = {
        "error": {
            "message": message,
            "type": "upstream_error",
            "code": code,
        }
    }
    return f"data: {json.dumps(chunk)}\n\n".encode()


async def stream_nvidia(payload: dict) -> AsyncGenerator[bytes, None]:
    """POST *payload* to the upstream chat-completions endpoint and relay its SSE stream.

    Each non-empty upstream line is re-emitted as its own SSE event. The
    stream always ends with ``data: [DONE]``: after the upstream terminator,
    after an upstream non-200 response, after a transport failure, or when
    upstream closes without sending the terminator.

    Args:
        payload: JSON-serializable request body to send upstream.

    Yields:
        UTF-8 encoded SSE events.
    """
    headers = {
        "Authorization": f"Bearer {NVIDIA_API_KEY}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    done_event = b"data: [DONE]\n\n"
    # Lines that are already valid non-data SSE syntax (comments such as
    # keep-alives, or event/id/retry fields) are relayed untouched instead of
    # being wrapped in a bogus "data:" event.
    sse_passthrough_prefixes = (":", "event:", "id:", "retry:")

    try:
        async with httpx.AsyncClient(timeout=300) as client:
            async with client.stream(
                "POST",
                f"{NVIDIA_BASE_URL}/chat/completions",
                headers=headers,
                json=payload,
            ) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    error_detail = body.decode(errors="replace")
                    yield _sse_error_event(
                        f"Upstream error {response.status_code}: {error_detail}",
                        response.status_code,
                    )
                    yield done_event
                    return

                async for raw_line in response.aiter_lines():
                    # Defensive: strip line endings in case the installed
                    # httpx version leaves them attached.
                    line = raw_line.rstrip("\r\n")
                    if not line.strip():
                        continue

                    if line.startswith("data:"):
                        # Covers both "data: x" and the equally valid "data:x".
                        yield f"{line}\n\n".encode()
                        if line[len("data:"):].strip() == "[DONE]":
                            return
                    elif line.startswith(sse_passthrough_prefixes):
                        yield f"{line}\n\n".encode()
                    else:
                        # Unexpected non-SSE text: surface it as a data event.
                        yield f"data: {line}\n\n".encode()
    except httpx.HTTPError as exc:
        # Headers are already sent once streaming starts, so a transport
        # failure can only be reported in-band as an SSE error event.
        yield _sse_error_event(f"Upstream connection error: {exc}", 502)
        yield done_event
        return

    # Upstream closed without a terminator; send one so clients stop waiting.
    yield done_event
225
 
226
+ # ---------------------------------------------------------------------------
227
+ # Routes
228
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
@app.get("/")
async def root():
    # Unauthenticated landing route: fixed service identification payload.
    info = {"status": "ok"}
    info["service"] = "AI Gateway"
    info["version"] = "1.0.0"
    return info
233
+
234
@app.get("/v1/models")
async def list_models(request: Request):
    # Authenticated listing of every display name in MODEL_MAP, shaped as a
    # "list" object whose entries all share the current time as "created".
    verify_api_key(request)
    created_at = int(time.time())
    entries = [
        {
            "id": display_name,
            "object": "model",
            "created": created_at,
            "owned_by": "ai-gateway",
        }
        for display_name in MODEL_MAP
    ]
    return {"object": "list", "data": entries}
247
 
248
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, req: ChatCompletionRequest):
    # Authenticate, resolve the model name, build the upstream payload and
    # relay the upstream SSE stream back to the caller.
    verify_api_key(request)

    upstream_payload = build_nvidia_payload(req, resolve_model(req.model))

    # Response headers that keep the event stream uncached and unbuffered.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    event_source = stream_nvidia(upstream_payload)
    return StreamingResponse(event_source, media_type="text/event-stream", headers=sse_headers)
264
 
265
+ # Passthrough completions (legacy)
266
@app.post("/v1/completions")
async def completions(request: Request):
    # Legacy passthrough: the client's JSON body is forwarded upstream with
    # only "model" (resolved when known) and "stream" (forced True) rewritten.
    #
    # NOTE(review): the body is handed to stream_nvidia(), which in this file
    # posts to the chat-completions endpoint. A legacy body that carries
    # "prompt" instead of "messages" will presumably be rejected upstream;
    # confirm whether this route should target a different endpoint.
    verify_api_key(request)

    # Malformed or non-object JSON is a client error, not a server error.
    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model_req = body.get("model", "")
    if not isinstance(model_req, str):
        raise HTTPException(status_code=400, detail="'model' must be a string")

    # Unknown names are deliberately passed through unchanged rather than
    # rejected, so upstream decides whether the model exists.
    try:
        nvidia_model = resolve_model(model_req)
    except HTTPException:
        nvidia_model = model_req

    body["model"] = nvidia_model
    body["stream"] = True

    return StreamingResponse(
        stream_nvidia(body),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
289
+
290
@app.get("/health")
async def health():
    # Unauthenticated probe route: always reports a fixed "healthy" status.
    report = {"status": "healthy"}
    return report