Update gen.py #7
by incognitolm - opened

gen.py CHANGED
@@ -66,6 +66,10 @@ MODEL_MAP = {
 FALLBACK_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
 FALLBACK_PROVIDER = "groq"
 
+# Header that API-key authenticated clients send so we know to stream
+# thinking tokens back to them.
+API_KEY_HEADER = "x-api-key"
+
 
 # ──────────────────────────────────────────────
 # CENTRAL ROUTING LOGIC
@@ -208,24 +212,173 @@ async def call_chat_completions(
     extra_body: Optional[Dict[str, Any]] = None,
 ) -> Dict[str, Any]:
     """
-
-
-
-
+    Resilient chat-completions call designed to survive Cloudflare 524 timeouts.
+
+    Strategy:
+      1. Ask the upstream for a *streaming* response so bytes arrive before
+         Cloudflare's ~100 s idle timeout fires.
+      2. Accumulate the stream into a single synthetic non-streaming payload
+         so callers don't need to change.
+      3. Retry up to 2 times (with a short back-off) on 502/503/524.
+      4. On exhausted retries fall through to the Groq fallback.
     """
     url, api_key = _get_provider_url_and_key(provider)
     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-
+
+    # Always request streaming upstream — we reassemble below.
+    body: Dict[str, Any] = {"model": model, "messages": messages, "stream": True}
     if extra_body:
         body.update(extra_body)
+    body["stream"] = True  # force streaming even if caller passed stream=False
 
-
-
+    TRANSIENT = {502, 503, 524, 429}
+    MAX_ATTEMPTS = 3
+
+    last_exc: Optional[Exception] = None
+
+    for attempt in range(MAX_ATTEMPTS):
+        if attempt:
+            await asyncio.sleep(2 ** attempt)  # 2 s, 4 s
+
+        try:
+            async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, read=300.0)) as client:
+                async with client.stream("POST", url, json=body, headers=headers) as r:
+                    # Transient upstream error — retry.
+                    if r.status_code in TRANSIENT:
+                        body_bytes = await r.aread()
+                        last_exc = HTTPException(
+                            status_code=r.status_code,
+                            detail=body_bytes.decode("utf-8", errors="replace")[:500],
+                        )
+                        print(f"[call_chat_completions] attempt {attempt+1} got {r.status_code}, retrying…")
+                        continue
+
+                    if r.status_code != 200:
+                        body_bytes = await r.aread()
+                        raise HTTPException(
+                            status_code=r.status_code,
+                            detail=body_bytes.decode("utf-8", errors="replace")[:1000],
+                        )
+
+                    # ── Reassemble streaming SSE into a single response object ──
+                    accumulated_content = ""
+                    accumulated_reasoning = ""
+                    tool_calls_map: Dict[int, Dict[str, Any]] = {}
+                    usage: Dict[str, Any] = {}
+                    finish_reason: Optional[str] = None
+                    resp_id = ""
+                    resp_model = model
+
+                    async for line in r.aiter_lines():
+                        if not line or not line.startswith("data:"):
+                            continue
+                        raw = line[5:].strip()
+                        if raw == "[DONE]":
+                            break
+                        try:
+                            obj = json.loads(raw)
+                        except Exception:
+                            continue
+
+                        if not isinstance(obj, dict):
+                            continue
+
+                        resp_id = resp_id or obj.get("id", "")
+                        resp_model = obj.get("model", resp_model)
+
+                        if "usage" in obj and obj["usage"]:
+                            usage = obj["usage"]
 
-
-
+                        choices = obj.get("choices") or []
+                        if not choices:
+                            continue
+
+                        choice = choices[0]
+                        finish_reason = choice.get("finish_reason") or finish_reason
+                        delta = choice.get("delta") or {}
+
+                        # Accumulate text content.
+                        dc = delta.get("content")
+                        if dc:
+                            accumulated_content += dc
+
+                        # Accumulate reasoning / thinking tokens.
+                        dr = delta.get("reasoning_content") or delta.get("reasoning")
+                        if dr:
+                            accumulated_reasoning += dr
+
+                        # Accumulate tool-call argument chunks (streamed as fragments).
+                        for tc_delta in (delta.get("tool_calls") or []):
+                            idx = tc_delta.get("index", 0)
+                            if idx not in tool_calls_map:
+                                tool_calls_map[idx] = {
+                                    "id": tc_delta.get("id", ""),
+                                    "type": tc_delta.get("type", "function"),
+                                    "function": {"name": "", "arguments": ""},
+                                }
+                            existing = tool_calls_map[idx]
+                            if tc_delta.get("id"):
+                                existing["id"] = tc_delta["id"]
+                            fn_delta = tc_delta.get("function") or {}
+                            if fn_delta.get("name"):
+                                existing["function"]["name"] += fn_delta["name"]
+                            if fn_delta.get("arguments"):
+                                existing["function"]["arguments"] += fn_delta["arguments"]
+
+                    # Reassemble into a standard non-streaming response shape.
+                    tool_calls_list = [tool_calls_map[i] for i in sorted(tool_calls_map)]
+
+                    message: Dict[str, Any] = {"role": "assistant", "content": accumulated_content}
+                    if accumulated_reasoning:
+                        message["reasoning_content"] = accumulated_reasoning
+                    if tool_calls_list:
+                        message["tool_calls"] = tool_calls_list
+
+                    return {
+                        "id": resp_id,
+                        "object": "chat.completion",
+                        "model": resp_model,
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": message,
+                                "finish_reason": finish_reason or "stop",
+                            }
+                        ],
+                        "usage": usage,
+                    }
 
-
+        except HTTPException:
+            raise
+        except (httpx.RemoteProtocolError, httpx.ReadError, httpx.ConnectError) as exc:
+            last_exc = exc
+            print(f"[call_chat_completions] attempt {attempt+1} network error: {exc}, retrying…")
+            continue
+
+    # All attempts exhausted — fall back to Groq.
+    print(f"[call_chat_completions] all attempts failed ({last_exc}), falling back to Groq")
+    fb_url, fb_key = _get_provider_url_and_key(FALLBACK_PROVIDER)
+    fb_headers = {"Authorization": f"Bearer {fb_key}", "Content-Type": "application/json"}
+    fallback_body = {
+        "model": FALLBACK_MODEL,
+        "messages": messages,
+        "stream": False,
+    }
+    if extra_body:
+        # Forward tools/tool_choice but not stream override.
+        for k in ("tools", "tool_choice"):
+            if k in extra_body:
+                fallback_body[k] = extra_body[k]
+
+    async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client:
+        fb_r = await client.post(fb_url, json=fallback_body, headers=fb_headers)
+
+    if fb_r.status_code != 200:
+        raise HTTPException(
+            status_code=fb_r.status_code,
+            detail=f"Primary and fallback both failed. Fallback: {fb_r.text[:500]}",
+        )
+    return fb_r.json()
 
 
 def _extract_text_from_response(data: Dict[str, Any]) -> str:
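Note: the subtlest part of the hunk above is the tool-call handling, since OpenAI-style streams deliver each call's name and JSON arguments as fragments that must be concatenated per index. A standalone, runnable sketch of that accumulation (the fragments are invented for illustration):

    # Invented example fragments, shaped like OpenAI-style tool_call deltas.
    deltas = [
        {"tool_calls": [{"index": 0, "id": "call_1", "type": "function",
                         "function": {"name": "get_", "arguments": ""}}]},
        {"tool_calls": [{"index": 0, "function": {"name": "weather",
                                                  "arguments": "{\"city\": "}}]},
        {"tool_calls": [{"index": 0, "function": {"arguments": "\"Paris\"}"}}]},
    ]

    tool_calls_map = {}
    for delta in deltas:
        for tc in delta.get("tool_calls") or []:
            idx = tc.get("index", 0)
            entry = tool_calls_map.setdefault(
                idx, {"id": "", "type": "function",
                      "function": {"name": "", "arguments": ""}})
            if tc.get("id"):
                entry["id"] = tc["id"]
            fn = tc.get("function") or {}
            entry["function"]["name"] += fn.get("name", "")
            entry["function"]["arguments"] += fn.get("arguments", "")

    print(tool_calls_map[0]["function"])
    # -> {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}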
@@ -253,6 +406,65 @@ def is_cinematic_image_prompt(prompt: str) -> bool:
     return False
 
 
+def _is_api_key_request(request: Request) -> bool:
+    """
+    Return True when the caller authenticated with an API key rather than a
+    session cookie / browser auth. We use this to decide whether to forward
+    think-tag / reasoning_content tokens to the client.
+    """
+    return bool(
+        request.headers.get(API_KEY_HEADER)
+        or request.headers.get("authorization", "").lower().startswith("bearer ")
+    )
+
+
+def _inject_reasoning_into_chunk(obj: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Some navy models return thinking tokens in a non-standard
+    ``reasoning_content`` field inside each delta. When that field is
+    present we wrap it in <think>…</think> and prepend it to the regular
+    ``content`` delta so that every SSE-speaking client sees a single,
+    unified text stream.
+
+    The original ``reasoning_content`` field is preserved so clients that
+    know about it can still use it directly.
+    """
+    try:
+        delta = obj["choices"][0]["delta"]
+    except (KeyError, IndexError, TypeError):
+        return obj
+
+    reasoning = delta.get("reasoning_content") or delta.get("reasoning") or ""
+    content = delta.get("content") or ""
+
+    if reasoning and isinstance(reasoning, str):
+        # Wrap in <think> tags and prepend to the visible content delta.
+        wrapped = f"<think>{reasoning}</think>"
+        delta["content"] = wrapped + content
+        # Keep the raw field so native clients can parse it too.
+        delta["reasoning_content"] = reasoning
+        obj["choices"][0]["delta"] = delta
+
+    return obj
+
+
+def _normalize_usage_block(obj: Dict[str, Any]) -> Dict[str, Any]:
+    """Rewrite the usage block to a canonical shape (in-place, returns obj)."""
+    if "usage" not in obj or not isinstance(obj.get("usage"), dict):
+        return obj
+    u = obj["usage"]
+    input_tok = u.get("prompt_tokens") or u.get("input_tokens", 0)
+    output_tok = u.get("completion_tokens") or u.get("output_tokens", 0)
+    obj["usage"] = {
+        "prompt_tokens": input_tok,
+        "completion_tokens": output_tok,
+        "total_tokens": input_tok + output_tok,
+        "input_tokens": input_tok,
+        "output_tokens": output_tok,
+    }
+    return obj
+
+
 # ──────────────────────────────────────────────
 # IMAGE GENERATION
 # ──────────────────────────────────────────────
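Note: for reference, this is how the two new helpers behave on a hand-made chunk (illustrative only; assumes the definitions above are in scope):

    chunk = {
        "choices": [{"delta": {"reasoning_content": "weighing options",
                               "content": "Answer: 42"}}],
        "usage": {"input_tokens": 10, "output_tokens": 3},
    }

    chunk = _inject_reasoning_into_chunk(chunk)
    chunk = _normalize_usage_block(chunk)

    print(chunk["choices"][0]["delta"]["content"])
    # -> <think>weighing options</think>Answer: 42
    print(chunk["usage"]["total_tokens"])
    # -> 13  (prompt_tokens/input_tokens 10, completion_tokens/output_tokens 3)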
@@ -682,6 +894,10 @@ async def generate_text(
 
     await _check_chat_rate_limit(request, authorization, x_client_id)
 
+    # Determine whether the caller is an API-key client that should receive
+    # raw thinking tokens.
+    forward_thinking = _is_api_key_request(request)
+
     body["model"] = chosen_model
     stream = body.get("stream", False)
 
@@ -744,39 +960,79 @@ async def generate_text(
         sent_metadata = False
         async with httpx.AsyncClient(timeout=None) as client:
             async for chunk in stream_primary(client):
+                # ── emit router metadata once as the very first SSE frame ──
                 if not sent_metadata:
-                    meta = {
+                    meta = {
+                        "router_metadata": {
+                            "model_name": MODEL_MAP.get(chosen_model, chosen_model)
+                        }
+                    }
                     yield f"data: {json.dumps(meta)}\n\n"
                     sent_metadata = True
 
-                #
-
-
+                # ── pass [DONE] straight through ──────────────────────────
+                if "data: [DONE]" in chunk:
+                    yield chunk
+                    continue
+
+                # ── process data: … lines ─────────────────────────────────
+                if chunk.startswith("data:"):
                     raw = chunk[5:].strip()
                     try:
                         obj = json.loads(raw)
-                        if isinstance(obj, dict) and "usage" in obj and isinstance(obj["usage"], dict):
-                            u = obj["usage"]
-                            input_tok = u.get("prompt_tokens") or u.get("input_tokens", 0)
-                            output_tok = u.get("completion_tokens") or u.get("output_tokens", 0)
-                            obj["usage"] = {
-                                "prompt_tokens": input_tok,
-                                "completion_tokens": output_tok,
-                                "total_tokens": input_tok + output_tok,
-                                "input_tokens": input_tok,
-                                "output_tokens": output_tok,
-                            }
-                            yield f"data: {json.dumps(obj)}\n\n"
-                            continue
                     except Exception:
-
+                        # Not valid JSON — forward verbatim (keeps partial
+                        # chunks from blocking the stream).
+                        yield chunk
+                        continue
+
+                    if not isinstance(obj, dict):
+                        yield chunk
+                        continue
+
+                    # Normalize usage block whenever it appears.
+                    _normalize_usage_block(obj)
+
+                    # ── thinking / reasoning tokens ───────────────────────
+                    # Navy models may embed thinking in two ways:
+                    #
+                    #   1. As delta.reasoning_content (separate field)
+                    #   2. Inline inside delta.content wrapped in <think>…</think>
+                    #
+                    # For API-key callers we always surface both forms.
+                    # For browser/session callers we strip reasoning_content
+                    # so it doesn't confuse UI clients that don't expect it,
+                    # but <think> tags already present in content are left
+                    # alone (they arrived that way from upstream).
+                    if forward_thinking:
+                        # Merge reasoning_content into content as
+                        # <think>…</think> and keep the raw field.
+                        obj = _inject_reasoning_into_chunk(obj)
+                    else:
+                        # Strip the non-standard field so browser clients
+                        # don't see unexpected keys.
+                        try:
+                            delta = obj["choices"][0]["delta"]
+                            delta.pop("reasoning_content", None)
+                            delta.pop("reasoning", None)
+                            obj["choices"][0]["delta"] = delta
+                        except (KeyError, IndexError, TypeError):
+                            pass
+
+                    yield f"data: {json.dumps(obj)}\n\n"
+                    continue
 
+                # ── any other line (comments, keep-alives, …) ─────────────
                 yield chunk
 
     return StreamingResponse(
         event_generator(),
         media_type="text/event-stream",
-        headers={
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",
+        },
     )
 
     # ── non-streaming ─────────────────────────
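Note: a minimal consumer of this stream could look like the sketch below. The request body and URL are placeholders, and `x-api-key` is the header that `_is_api_key_request` checks:

    import json
    import httpx

    async def consume(url: str, api_key: str) -> None:
        # The first frame carries router_metadata; later frames are normal
        # chat.completion chunks, with <think>…</think> prepended to the
        # content delta for API-key callers. (URL and body are placeholders.)
        body = {"messages": [{"role": "user", "content": "hi"}], "stream": True}
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", url, json=body,
                                     headers={"x-api-key": api_key}) as r:
                async for line in r.aiter_lines():
                    if not line.startswith("data:"):
                        continue          # SSE comments (": ping") and blanks
                    raw = line[5:].strip()
                    if raw == "[DONE]":
                        break
                    obj = json.loads(raw)
                    if "router_metadata" in obj:
                        print("routed to:", obj["router_metadata"]["model_name"])
                        continue
                    choices = obj.get("choices") or [{}]
                    print((choices[0].get("delta") or {}).get("content") or "", end="")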
@@ -789,7 +1045,11 @@ async def generate_text(
             fb_url, fb_key = _get_provider_url_and_key(FALLBACK_PROVIDER)
             fallback_body = dict(body)
             fallback_body["model"] = FALLBACK_MODEL
-            r = await client.post(
+            r = await client.post(
+                fb_url,
+                json=fallback_body,
+                headers={"Authorization": f"Bearer {fb_key}"},
+            )
 
             content_type = (r.headers.get("content-type") or "").lower()
             if "application/json" in content_type:
@@ -798,22 +1058,35 @@ async def generate_text(
             except Exception:
                 payload = {"error": "Upstream returned invalid JSON"}
             else:
-                # Normalize usage
-
-
-
-
-
-
-
-
-
-
-
-                "
-                "
-
-
+                # Normalize usage fields.
+                _normalize_usage_block(payload)
+
+                # ── thinking tokens in non-streaming responses ────────────────────
+                # Some navy models put thinking content in
+                # message.reasoning_content. For API-key callers we prepend it to
+                # message.content wrapped in <think>…</think>; for others we drop
+                # the non-standard field.
+                try:
+                    message = payload["choices"][0]["message"]
+                    reasoning = (
+                        message.pop("reasoning_content", None)
+                        or message.pop("reasoning", None)
+                        or ""
+                    )
+                    if reasoning and isinstance(reasoning, str):
+                        if forward_thinking:
+                            existing = message.get("content") or ""
+                            message["content"] = f"<think>{reasoning}</think>{existing}"
+                            # Restore the raw field for clients that want it.
+                            message["reasoning_content"] = reasoning
+                        # else: already popped — nothing to do.
+                    payload["choices"][0]["message"] = message
+                except (KeyError, IndexError, TypeError):
+                    pass
+
+                payload.setdefault("router_metadata", {})["model_name"] = MODEL_MAP.get(
+                    chosen_model, chosen_model
+                )
         else:
             payload = {
                 "error": "Upstream returned non-JSON response",
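Note: the non-streaming branch performs the same transformation in one pass. On a hypothetical payload:

    payload = {
        "choices": [{"message": {"role": "assistant", "content": "42",
                                 "reasoning_content": "6 * 7 = 42"}}],
        "usage": {"prompt_tokens": 12, "completion_tokens": 1},
    }

    # forward_thinking=True  -> message.content becomes
    #                           "<think>6 * 7 = 42</think>42" and
    #                           reasoning_content is restored alongside it.
    # forward_thinking=False -> reasoning_content is popped; content stays "42".
    # In both cases usage gains total_tokens plus the input/output aliases,
    # and router_metadata.model_name records the routed model.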
@@ -1063,8 +1336,24 @@ async def create_responses(
             },
         })
 
+        # ── Run _generate() in the background, pinging every 15 s ──────────────
+        # Without keepalive bytes, Cloudflare (524) and Codex both drop the
+        # connection while the model is thinking or accumulating tool arguments.
+        # SSE comment lines (": ping") are invisible to application code but
+        # reset every proxy's idle-timeout counter.
+        PING_INTERVAL = 15  # seconds
+        gen_task: asyncio.Task = asyncio.ensure_future(_generate())
+
+        while not gen_task.done():
+            try:
+                await asyncio.wait_for(asyncio.shield(gen_task), timeout=PING_INTERVAL)
+            except asyncio.TimeoutError:
+                yield ": ping\n\n"
+            except Exception:
+                break  # real error — handled below
+
         try:
-            text, tool_calls, input_tokens, output_tokens =
+            text, tool_calls, input_tokens, output_tokens = gen_task.result()
         except HTTPException as exc:
             yield sse("response.failed", {
                 "type": "response.failed",
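Note: the shield/wait_for pattern can be exercised in isolation. In this runnable toy, `slow_work` stands in for `_generate()`; it prints two pings, then the result:

    import asyncio

    async def slow_work() -> str:
        await asyncio.sleep(40)              # stands in for _generate()
        return "done"

    async def stream_with_pings() -> None:
        task = asyncio.ensure_future(slow_work())
        while not task.done():
            try:
                # shield() keeps the timeout from cancelling the worker.
                await asyncio.wait_for(asyncio.shield(task), timeout=15)
            except asyncio.TimeoutError:
                print(": ping")              # SSE comment line; clients ignore it
            except Exception:
                break                        # real error; surfaces via result()
        print(task.result())

    asyncio.run(stream_with_pings())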
@@ -1076,6 +1365,17 @@ async def create_responses(
             })
             yield "data: [DONE]\n\n"
             return
+        except Exception as exc:
+            yield sse("response.failed", {
+                "type": "response.failed",
+                "response": {
+                    "id": response_id, "object": "response",
+                    "created_at": ts, "status": "failed", "model": chosen_model,
+                    "error": {"code": "upstream_error", "message": str(exc)},
+                },
+            })
+            yield "data: [DONE]\n\n"
+            return
 
         output_index = 0
 