oki692 committed on
Commit
f3b2bb4
·
verified ·
1 Parent(s): 6ca3422

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +95 -106
main.py CHANGED
@@ -2,17 +2,17 @@
2
  Multi-model AI gateway endpoint — HF Spaces compatible.
3
  Authorization via 'connect' API key header.
4
  Streaming always enabled. Function calling supported.
 
5
  """
6
 
7
  import json
8
- import asyncio
9
  from typing import AsyncGenerator, Optional
10
 
 
11
  from fastapi import FastAPI, HTTPException, Header, Request
12
  from fastapi.responses import StreamingResponse
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from pydantic import BaseModel, Field
15
- from openai import OpenAI
16
 
17
  from system_prompts import get_system_prompt
18
 
@@ -22,6 +22,7 @@ CONNECT_KEY = "connect"
22
 
23
  NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
24
  NVIDIA_API_KEY = "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"
 
25
 
26
  # Model registry: display-name → real model id + optional extra body
27
  MODELS = {
@@ -72,12 +73,9 @@ app.add_middleware(
72
  allow_headers=["*"],
73
  )
74
 
75
- client = OpenAI(base_url=NVIDIA_BASE_URL, api_key=NVIDIA_API_KEY)
76
-
77
  # ── Auth ─────────────────────────────────────────────────────────────────────
78
 
79
  def verify_key(authorization: Optional[str]) -> None:
80
- """Check Bearer token matches CONNECT_KEY."""
81
  if not authorization:
82
  raise HTTPException(status_code=401, detail="Missing Authorization header")
83
  scheme, _, token = authorization.partition(" ")
@@ -88,7 +86,7 @@ def verify_key(authorization: Optional[str]) -> None:
88
 
89
  class Message(BaseModel):
90
  role: str
91
- content: str | list # supports text or multipart
92
 
93
  class ToolFunction(BaseModel):
94
  name: str
@@ -100,7 +98,7 @@ class Tool(BaseModel):
100
  function: ToolFunction
101
 
102
  class ChatRequest(BaseModel):
103
- model: str = Field(..., description="Model name: Bielik-11b | GLM-4.7 | Mistral-Small-4 | DeepSeek-V3.1 | Kimi-K2")
104
  messages: list[Message]
105
  tools: Optional[list[Tool]] = None
106
  tool_choice: Optional[str | dict] = None
@@ -109,49 +107,88 @@ class ChatRequest(BaseModel):
109
  top_p: Optional[float] = None
110
  presence_penalty: Optional[float] = None
111
  frequency_penalty: Optional[float] = None
112
- inject_system_prompt: bool = Field(
113
- default=True,
114
- description="Prepend the model-specific system prompt automatically"
115
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- # ── Stream helper ─────────────────────────────────────────────────────────────
118
 
119
  async def stream_nvidia(
120
  model_name: str,
121
  messages: list[dict],
122
- tools: Optional[list[dict]],
123
  tool_choice,
124
  kwargs: dict,
125
- extra_body: dict,
126
- ) -> AsyncGenerator[str, None]:
127
- """Yield SSE chunks from NVIDIA NIM in a thread-safe way."""
128
-
129
- params = {
130
- "model": MODELS[model_name]["model_id"],
131
- "messages": messages,
132
- "stream": True, # always True
133
- **kwargs,
134
  }
135
 
136
- if tools:
137
- params["tools"] = tools
138
- if tool_choice is not None:
139
- params["tool_choice"] = tool_choice
140
- if extra_body:
141
- params["extra_body"] = extra_body
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- loop = asyncio.get_event_loop()
144
 
145
- def _call():
146
- return client.chat.completions.create(**params)
 
 
 
 
 
 
 
 
147
 
148
- stream = await loop.run_in_executor(None, _call)
149
 
150
- for chunk in stream:
151
- data = chunk.model_dump()
152
- yield f"data: {json.dumps(data)}\n\n"
153
 
154
- yield "data: [DONE]\n\n"
 
 
 
155
 
156
  # ── Endpoints ─────────────────────────────────────────────────────────────────
157
 
@@ -160,7 +197,7 @@ async def root():
160
  return {
161
  "service": "Multi-Model AI Gateway",
162
  "models": list(MODELS.keys()),
163
- "auth": "Bearer <connect-key>",
164
  "docs": "/docs",
165
  }
166
 
@@ -169,10 +206,7 @@ async def root():
169
  async def list_models(authorization: Optional[str] = Header(default=None)):
170
  verify_key(authorization)
171
  return {
172
- name: {
173
- "model_id": cfg["model_id"],
174
- "has_thinking": bool(cfg["extra_body"]),
175
- }
176
  for name, cfg in MODELS.items()
177
  }
178
 
@@ -190,54 +224,27 @@ async def chat(
190
  detail=f"Unknown model '{request.model}'. Available: {list(MODELS.keys())}",
191
  )
192
 
193
- cfg = MODELS[request.model]
194
-
195
- # Build messages list
196
- messages = [m.model_dump() for m in request.messages]
197
-
198
- # Inject per-model system prompt at position 0 if not already present
199
- if request.inject_system_prompt:
200
- system_prompt = get_system_prompt(request.model)
201
- if not messages or messages[0].get("role") != "system":
202
- messages.insert(0, {"role": "system", "content": system_prompt})
203
-
204
- # Optional params
205
- kwargs = {}
206
- for field in ("temperature", "max_tokens", "top_p", "presence_penalty", "frequency_penalty"):
207
- val = getattr(request, field)
208
- if val is not None:
209
- kwargs[field] = val
210
-
211
- tools = [t.model_dump() for t in request.tools] if request.tools else None
212
 
213
  return StreamingResponse(
214
- stream_nvidia(
215
- model_name=request.model,
216
- messages=messages,
217
- tools=tools,
218
- tool_choice=request.tool_choice,
219
- kwargs=kwargs,
220
- extra_body=cfg["extra_body"],
221
- ),
222
  media_type="text/event-stream",
223
- headers={
224
- "Cache-Control": "no-cache",
225
- "X-Accel-Buffering": "no",
226
- },
227
  )
228
 
229
 
230
- # ── Compatibility: OpenAI-style /v1/chat/completions ──────────────────────────
231
-
232
  @app.post("/v1/chat/completions")
233
  async def openai_compat(
234
  raw: Request,
235
  authorization: Optional[str] = Header(default=None),
236
  ):
237
- """
238
- Drop-in OpenAI-compatible endpoint.
239
- Pass model as one of the gateway model names (e.g. 'Kimi-K2').
240
- """
241
  verify_key(authorization)
242
  body = await raw.json()
243
 
@@ -248,35 +255,17 @@ async def openai_compat(
248
  detail=f"Unknown model '{model_name}'. Available: {list(MODELS.keys())}",
249
  )
250
 
251
- cfg = MODELS[model_name]
252
- messages = body.get("messages", [])
253
-
254
- inject = body.get("inject_system_prompt", True)
255
- if inject:
256
- system_prompt = get_system_prompt(model_name)
257
- if not messages or messages[0].get("role") != "system":
258
- messages.insert(0, {"role": "system", "content": system_prompt})
259
-
260
- kwargs = {}
261
- for field in ("temperature", "max_tokens", "top_p", "presence_penalty", "frequency_penalty"):
262
- if field in body:
263
- kwargs[field] = body[field]
264
-
265
- tools = body.get("tools")
266
  tool_choice = body.get("tool_choice")
267
 
268
  return StreamingResponse(
269
- stream_nvidia(
270
- model_name=model_name,
271
- messages=messages,
272
- tools=tools,
273
- tool_choice=tool_choice,
274
- kwargs=kwargs,
275
- extra_body=cfg["extra_body"],
276
- ),
277
  media_type="text/event-stream",
278
- headers={
279
- "Cache-Control": "no-cache",
280
- "X-Accel-Buffering": "no",
281
- },
282
- )
 
2
  Multi-model AI gateway endpoint — HF Spaces compatible.
3
  Authorization via 'connect' API key header.
4
  Streaming always enabled. Function calling supported.
5
+ Uses httpx async — no openai SDK network issues.
6
  """
7
 
8
  import json
 
9
  from typing import AsyncGenerator, Optional
10
 
11
+ import httpx
12
  from fastapi import FastAPI, HTTPException, Header, Request
13
  from fastapi.responses import StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel, Field
 
16
 
17
  from system_prompts import get_system_prompt
18
 
 
22
 
23
  NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
24
  NVIDIA_API_KEY = "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"
25
+ NVIDIA_CHAT_URL = f"{NVIDIA_BASE_URL}/chat/completions"
26
 
27
  # Model registry: display-name → real model id + optional extra body
28
  MODELS = {
 
73
  allow_headers=["*"],
74
  )
75
 
 
 
76
  # ── Auth ─────────────────────────────────────────────────────────────────────
77
 
78
  def verify_key(authorization: Optional[str]) -> None:
 
79
  if not authorization:
80
  raise HTTPException(status_code=401, detail="Missing Authorization header")
81
  scheme, _, token = authorization.partition(" ")
 
86
 
87
  class Message(BaseModel):
88
  role: str
89
+ content: str | list
90
 
91
  class ToolFunction(BaseModel):
92
  name: str
 
98
  function: ToolFunction
99
 
100
  class ChatRequest(BaseModel):
101
+ model: str = Field(..., description="Bielik-11b | GLM-4.7 | Mistral-Small-4 | DeepSeek-V3.1 | Kimi-K2")
102
  messages: list[Message]
103
  tools: Optional[list[Tool]] = None
104
  tool_choice: Optional[str | dict] = None
 
107
  top_p: Optional[float] = None
108
  presence_penalty: Optional[float] = None
109
  frequency_penalty: Optional[float] = None
110
+ inject_system_prompt: bool = Field(default=True)
111
+
112
+ # ── Core stream helper ────────────────────────────────────────────────────────
113
+
114
+ def _build_payload(model_name: str, messages: list[dict], tools, tool_choice, kwargs: dict) -> dict:
115
+ cfg = MODELS[model_name]
116
+ payload: dict = {
117
+ "model": cfg["model_id"],
118
+ "messages": messages,
119
+ "stream": True,
120
+ **kwargs,
121
+ }
122
+ if tools:
123
+ payload["tools"] = tools
124
+ if tool_choice is not None:
125
+ payload["tool_choice"] = tool_choice
126
+ # merge extra_body fields at top level (NVIDIA NIM style)
127
+ if cfg["extra_body"]:
128
+ payload.update(cfg["extra_body"])
129
+ return payload
130
 
 
131
 
132
  async def stream_nvidia(
133
  model_name: str,
134
  messages: list[dict],
135
+ tools,
136
  tool_choice,
137
  kwargs: dict,
138
+ ) -> AsyncGenerator[bytes, None]:
139
+ payload = _build_payload(model_name, messages, tools, tool_choice, kwargs)
140
+ headers = {
141
+ "Authorization": f"Bearer {NVIDIA_API_KEY}",
142
+ "Content-Type": "application/json",
143
+ "Accept": "text/event-stream",
 
 
 
144
  }
145
 
146
+ async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=10.0)) as client:
147
+ async with client.stream(
148
+ "POST",
149
+ NVIDIA_CHAT_URL,
150
+ headers=headers,
151
+ json=payload,
152
+ ) as response:
153
+ if response.status_code != 200:
154
+ body = await response.aread()
155
+ error_msg = body.decode(errors="replace")
156
+ yield f"data: {json.dumps({'error': error_msg, 'status': response.status_code})}\n\n".encode()
157
+ return
158
+
159
+ async for line in response.aiter_lines():
160
+ if line:
161
+ yield f"{line}\n\n".encode()
162
+
163
+ # ── Shared logic ──────────────────────────────────────────────────────────────
164
+
165
+ def prepare_messages(model_name: str, raw_messages: list[dict], inject: bool) -> list[dict]:
166
+ messages = list(raw_messages)
167
+ if inject:
168
+ system_prompt = get_system_prompt(model_name)
169
+ if not messages or messages[0].get("role") != "system":
170
+ messages.insert(0, {"role": "system", "content": system_prompt})
171
+ return messages
172
 
 
173
 
174
+ def extract_kwargs(source, fields: tuple) -> dict:
175
+ kwargs = {}
176
+ for field in fields:
177
+ if isinstance(source, dict):
178
+ val = source.get(field)
179
+ else:
180
+ val = getattr(source, field, None)
181
+ if val is not None:
182
+ kwargs[field] = val
183
+ return kwargs
184
 
 
185
 
186
+ OPTIONAL_FIELDS = ("temperature", "max_tokens", "top_p", "presence_penalty", "frequency_penalty")
 
 
187
 
188
+ SSE_HEADERS = {
189
+ "Cache-Control": "no-cache",
190
+ "X-Accel-Buffering": "no",
191
+ }
192
 
193
  # ── Endpoints ─────────────────────────────────────────────────────────────────
194
 
 
197
  return {
198
  "service": "Multi-Model AI Gateway",
199
  "models": list(MODELS.keys()),
200
+ "auth": "Bearer connect",
201
  "docs": "/docs",
202
  }
203
 
 
206
  async def list_models(authorization: Optional[str] = Header(default=None)):
207
  verify_key(authorization)
208
  return {
209
+ name: {"model_id": cfg["model_id"]}
 
 
 
210
  for name, cfg in MODELS.items()
211
  }
212
 
 
224
  detail=f"Unknown model '{request.model}'. Available: {list(MODELS.keys())}",
225
  )
226
 
227
+ messages = prepare_messages(
228
+ request.model,
229
+ [m.model_dump() for m in request.messages],
230
+ request.inject_system_prompt,
231
+ )
232
+ kwargs = extract_kwargs(request, OPTIONAL_FIELDS)
233
+ tools = [t.model_dump() for t in request.tools] if request.tools else None
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  return StreamingResponse(
236
+ stream_nvidia(request.model, messages, tools, request.tool_choice, kwargs),
 
 
 
 
 
 
 
237
  media_type="text/event-stream",
238
+ headers=SSE_HEADERS,
 
 
 
239
  )
240
 
241
 
 
 
242
  @app.post("/v1/chat/completions")
243
  async def openai_compat(
244
  raw: Request,
245
  authorization: Optional[str] = Header(default=None),
246
  ):
247
+ """OpenAI-compatible drop-in. Use gateway model names as 'model'."""
 
 
 
248
  verify_key(authorization)
249
  body = await raw.json()
250
 
 
255
  detail=f"Unknown model '{model_name}'. Available: {list(MODELS.keys())}",
256
  )
257
 
258
+ messages = prepare_messages(
259
+ model_name,
260
+ body.get("messages", []),
261
+ body.get("inject_system_prompt", True),
262
+ )
263
+ kwargs = extract_kwargs(body, OPTIONAL_FIELDS)
264
+ tools = body.get("tools")
 
 
 
 
 
 
 
 
265
  tool_choice = body.get("tool_choice")
266
 
267
  return StreamingResponse(
268
+ stream_nvidia(model_name, messages, tools, tool_choice, kwargs),
 
 
 
 
 
 
 
269
  media_type="text/event-stream",
270
+ headers=SSE_HEADERS,
271
+ )