Files changed (1)
  1. gen.py +347 -47
gen.py CHANGED
@@ -66,6 +66,10 @@ MODEL_MAP = {
 FALLBACK_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
 FALLBACK_PROVIDER = "groq"
 
+# Header that API-key authenticated clients send so we know to stream
+# thinking tokens back to them.
+API_KEY_HEADER = "x-api-key"
+
 
 # ──────────────────────────────────────────────
 # CENTRAL ROUTING LOGIC
@@ -208,24 +212,173 @@ async def call_chat_completions(
     extra_body: Optional[Dict[str, Any]] = None,
 ) -> Dict[str, Any]:
     """
-    Non-streaming chat-completions call.
-
-    Returns the full upstream JSON payload.
-    Raises HTTPException on upstream errors.
+    Resilient chat-completions call designed to survive Cloudflare 524 timeouts.
+
+    Strategy:
+      1. Ask the upstream for a *streaming* response so bytes arrive before
+         Cloudflare's ~100 s idle timeout fires.
+      2. Accumulate the stream into a single synthetic non-streaming payload
+         so callers don't need to change.
+      3. Retry up to 2 times (with a short back-off) on 502/503/524.
+      4. On exhausted retries fall through to the Groq fallback.
     """
     url, api_key = _get_provider_url_and_key(provider)
     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-    body = {"model": model, "messages": messages, "stream": False}
+
+    # Always request streaming upstream — we reassemble below.
+    body: Dict[str, Any] = {"model": model, "messages": messages, "stream": True}
     if extra_body:
         body.update(extra_body)
+    body["stream"] = True  # force streaming even if caller passed stream=False
 
-    async with httpx.AsyncClient(timeout=None) as client:
-        r = await client.post(url, json=body, headers=headers)
+    TRANSIENT = {502, 503, 524, 429}
+    MAX_ATTEMPTS = 3
+
+    last_exc: Optional[Exception] = None
+
+    for attempt in range(MAX_ATTEMPTS):
+        if attempt:
+            await asyncio.sleep(2 ** attempt)  # 2 s, 4 s
+
+        try:
+            async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, read=300.0)) as client:
+                async with client.stream("POST", url, json=body, headers=headers) as r:
+                    # Transient upstream error — retry.
+                    if r.status_code in TRANSIENT:
+                        body_bytes = await r.aread()
+                        last_exc = HTTPException(
+                            status_code=r.status_code,
+                            detail=body_bytes.decode("utf-8", errors="replace")[:500],
+                        )
+                        print(f"[call_chat_completions] attempt {attempt+1} got {r.status_code}, retrying…")
+                        continue
+
+                    if r.status_code != 200:
+                        body_bytes = await r.aread()
+                        raise HTTPException(
+                            status_code=r.status_code,
+                            detail=body_bytes.decode("utf-8", errors="replace")[:1000],
+                        )
+
+                    # ── Reassemble streaming SSE into a single response object ──
+                    accumulated_content = ""
+                    accumulated_reasoning = ""
+                    tool_calls_map: Dict[int, Dict[str, Any]] = {}
+                    usage: Dict[str, Any] = {}
+                    finish_reason: Optional[str] = None
+                    resp_id = ""
+                    resp_model = model
+
+                    async for line in r.aiter_lines():
+                        if not line or not line.startswith("data:"):
+                            continue
+                        raw = line[5:].strip()
+                        if raw == "[DONE]":
+                            break
+                        try:
+                            obj = json.loads(raw)
+                        except Exception:
+                            continue
+
+                        if not isinstance(obj, dict):
+                            continue
+
+                        resp_id = resp_id or obj.get("id", "")
+                        resp_model = obj.get("model", resp_model)
+
+                        if "usage" in obj and obj["usage"]:
+                            usage = obj["usage"]
 
-        if r.status_code != 200:
-            raise HTTPException(status_code=r.status_code, detail=r.text[:1000])
+                        choices = obj.get("choices") or []
+                        if not choices:
+                            continue
+
+                        choice = choices[0]
+                        finish_reason = choice.get("finish_reason") or finish_reason
+                        delta = choice.get("delta") or {}
+
+                        # Accumulate text content.
+                        dc = delta.get("content")
+                        if dc:
+                            accumulated_content += dc
+
+                        # Accumulate reasoning / thinking tokens.
+                        dr = delta.get("reasoning_content") or delta.get("reasoning")
+                        if dr:
+                            accumulated_reasoning += dr
+
+                        # Accumulate tool-call argument chunks (streamed as fragments).
+                        for tc_delta in (delta.get("tool_calls") or []):
+                            idx = tc_delta.get("index", 0)
+                            if idx not in tool_calls_map:
+                                tool_calls_map[idx] = {
+                                    "id": tc_delta.get("id", ""),
+                                    "type": tc_delta.get("type", "function"),
+                                    "function": {"name": "", "arguments": ""},
+                                }
+                            existing = tool_calls_map[idx]
+                            if tc_delta.get("id"):
+                                existing["id"] = tc_delta["id"]
+                            fn_delta = tc_delta.get("function") or {}
+                            if fn_delta.get("name"):
+                                existing["function"]["name"] += fn_delta["name"]
+                            if fn_delta.get("arguments"):
+                                existing["function"]["arguments"] += fn_delta["arguments"]
+
+                    # Reassemble into a standard non-streaming response shape.
+                    tool_calls_list = [tool_calls_map[i] for i in sorted(tool_calls_map)]
+
+                    message: Dict[str, Any] = {"role": "assistant", "content": accumulated_content}
+                    if accumulated_reasoning:
+                        message["reasoning_content"] = accumulated_reasoning
+                    if tool_calls_list:
+                        message["tool_calls"] = tool_calls_list
+
+                    return {
+                        "id": resp_id,
+                        "object": "chat.completion",
+                        "model": resp_model,
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": message,
+                                "finish_reason": finish_reason or "stop",
+                            }
+                        ],
+                        "usage": usage,
+                    }
 
-        return r.json()
+        except HTTPException:
+            raise
+        except (httpx.RemoteProtocolError, httpx.ReadError, httpx.ConnectError) as exc:
+            last_exc = exc
+            print(f"[call_chat_completions] attempt {attempt+1} network error: {exc}, retrying…")
+            continue
+
+    # All attempts exhausted — fall back to Groq.
+    print(f"[call_chat_completions] all attempts failed ({last_exc}), falling back to Groq")
+    fb_url, fb_key = _get_provider_url_and_key(FALLBACK_PROVIDER)
+    fb_headers = {"Authorization": f"Bearer {fb_key}", "Content-Type": "application/json"}
+    fallback_body = {
+        "model": FALLBACK_MODEL,
+        "messages": messages,
+        "stream": False,
+    }
+    if extra_body:
+        # Forward tools/tool_choice but not stream override.
+        for k in ("tools", "tool_choice"):
+            if k in extra_body:
+                fallback_body[k] = extra_body[k]
+
+    async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client:
+        fb_r = await client.post(fb_url, json=fallback_body, headers=fb_headers)
+
+    if fb_r.status_code != 200:
+        raise HTTPException(
+            status_code=fb_r.status_code,
+            detail=f"Primary and fallback both failed. Fallback: {fb_r.text[:500]}",
+        )
+    return fb_r.json()
 
 
 def _extract_text_from_response(data: Dict[str, Any]) -> str:
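
Note on the stream-and-reassemble approach above: the accumulation logic can be exercised on its own, without httpx or a live upstream. A minimal sketch under those assumptions (the function name reassemble_sse and the sample lines are illustrative, not part of gen.py):

import json
from typing import Any, Dict, Iterable

def reassemble_sse(lines: Iterable[str]) -> Dict[str, Any]:
    # Collapse OpenAI-style "data: {...}" stream lines into one chat.completion dict.
    content, usage, finish = "", {}, None
    for line in lines:
        if not line.startswith("data:"):
            continue                      # ignore comments / keep-alives
        raw = line[5:].strip()
        if raw == "[DONE]":
            break
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError:
            continue
        usage = obj.get("usage") or usage
        for choice in obj.get("choices") or []:
            finish = choice.get("finish_reason") or finish
            content += (choice.get("delta") or {}).get("content") or ""
    return {
        "object": "chat.completion",
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": content},
            "finish_reason": finish or "stop",
        }],
        "usage": usage,
    }

# reassemble_sse([
#     'data: {"choices": [{"delta": {"content": "Hel"}}]}',
#     'data: {"choices": [{"delta": {"content": "lo"}}]}',
#     'data: [DONE]',
# ])["choices"][0]["message"]["content"]  ->  "Hello"
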
@@ -253,6 +406,65 @@ def is_cinematic_image_prompt(prompt: str) -> bool:
     return False
 
 
+def _is_api_key_request(request: Request) -> bool:
+    """
+    Return True when the caller authenticated with an API key rather than a
+    session cookie / browser auth. We use this to decide whether to forward
+    think-tag / reasoning_content tokens to the client.
+    """
+    return bool(
+        request.headers.get(API_KEY_HEADER)
+        or request.headers.get("authorization", "").lower().startswith("bearer ")
+    )
+
+
+def _inject_reasoning_into_chunk(obj: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Some navy models return thinking tokens in a non-standard
+    ``reasoning_content`` field inside each delta. When that field is
+    present we wrap it in <think>…</think> and prepend it to the regular
+    ``content`` delta so that every SSE-speaking client sees a single,
+    unified text stream.
+
+    The original ``reasoning_content`` field is preserved so clients that
+    know about it can still use it directly.
+    """
+    try:
+        delta = obj["choices"][0]["delta"]
+    except (KeyError, IndexError, TypeError):
+        return obj
+
+    reasoning = delta.get("reasoning_content") or delta.get("reasoning") or ""
+    content = delta.get("content") or ""
+
+    if reasoning and isinstance(reasoning, str):
+        # Wrap in <think> tags and prepend to the visible content delta.
+        wrapped = f"<think>{reasoning}</think>"
+        delta["content"] = wrapped + content
+        # Keep the raw field so native clients can parse it too.
+        delta["reasoning_content"] = reasoning
+        obj["choices"][0]["delta"] = delta
+
+    return obj
+
+
+def _normalize_usage_block(obj: Dict[str, Any]) -> Dict[str, Any]:
+    """Rewrite the usage block to a canonical shape (in-place, returns obj)."""
+    if "usage" not in obj or not isinstance(obj.get("usage"), dict):
+        return obj
+    u = obj["usage"]
+    input_tok = u.get("prompt_tokens") or u.get("input_tokens", 0)
+    output_tok = u.get("completion_tokens") or u.get("output_tokens", 0)
+    obj["usage"] = {
+        "prompt_tokens": input_tok,
+        "completion_tokens": output_tok,
+        "total_tokens": input_tok + output_tok,
+        "input_tokens": input_tok,
+        "output_tokens": output_tok,
+    }
+    return obj
+
+
 # ──────────────────────────────────────────────
 # IMAGE GENERATION
 # ──────────────────────────────────────────────
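
The canonical usage shape that _normalize_usage_block settles on can be sanity-checked with a tiny snippet; the token counts below are made-up sample values:

chunk = {"usage": {"prompt_tokens": 12, "completion_tokens": 34}}
_normalize_usage_block(chunk)   # mutates in place and returns the same dict
assert chunk["usage"] == {
    "prompt_tokens": 12,
    "completion_tokens": 34,
    "total_tokens": 46,
    "input_tokens": 12,
    "output_tokens": 34,
}
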
@@ -682,6 +894,10 @@ async def generate_text(
 
     await _check_chat_rate_limit(request, authorization, x_client_id)
 
+    # Determine whether the caller is an API-key client that should receive
+    # raw thinking tokens.
+    forward_thinking = _is_api_key_request(request)
+
     body["model"] = chosen_model
     stream = body.get("stream", False)
 
@@ -744,39 +960,79 @@ async def generate_text(
         sent_metadata = False
         async with httpx.AsyncClient(timeout=None) as client:
             async for chunk in stream_primary(client):
+                # ── emit router metadata once as the very first SSE frame ──
                 if not sent_metadata:
-                    meta = {"router_metadata": {"model_name": MODEL_MAP.get(chosen_model, chosen_model)}}
+                    meta = {
+                        "router_metadata": {
+                            "model_name": MODEL_MAP.get(chosen_model, chosen_model)
+                        }
+                    }
                     yield f"data: {json.dumps(meta)}\n\n"
                     sent_metadata = True
 
-                # Intercept the final non-[DONE] data chunk and normalize
-                # the usage block so callers always see consistent field names.
-                if chunk.startswith("data:") and "[DONE]" not in chunk:
+                # ── pass [DONE] straight through ──────────────────────────
+                if "data: [DONE]" in chunk:
+                    yield chunk
+                    continue
+
+                # ── process data: … lines ─────────────────────────────────
+                if chunk.startswith("data:"):
                     raw = chunk[5:].strip()
                     try:
                         obj = json.loads(raw)
-                        if isinstance(obj, dict) and "usage" in obj and isinstance(obj["usage"], dict):
-                            u = obj["usage"]
-                            input_tok = u.get("prompt_tokens") or u.get("input_tokens", 0)
-                            output_tok = u.get("completion_tokens") or u.get("output_tokens", 0)
-                            obj["usage"] = {
-                                "prompt_tokens": input_tok,
-                                "completion_tokens": output_tok,
-                                "total_tokens": input_tok + output_tok,
-                                "input_tokens": input_tok,
-                                "output_tokens": output_tok,
-                            }
-                            yield f"data: {json.dumps(obj)}\n\n"
-                            continue
                     except Exception:
-                        pass
+                        # Not valid JSON — forward verbatim (keeps partial
+                        # chunks from blocking the stream).
+                        yield chunk
+                        continue
+
+                    if not isinstance(obj, dict):
+                        yield chunk
+                        continue
+
+                    # Normalize usage block whenever it appears.
+                    _normalize_usage_block(obj)
+
+                    # ── thinking / reasoning tokens ───────────────────────
+                    # Navy models may embed thinking in two ways:
+                    #
+                    #   1. As delta.reasoning_content (separate field)
+                    #   2. Inline inside delta.content wrapped in <think>…</think>
+                    #
+                    # For API-key callers we always surface both forms.
+                    # For browser/session callers we strip reasoning_content
+                    # so it doesn't confuse UI clients that don't expect it,
+                    # but <think> tags already present in content are left
+                    # alone (they arrived that way from upstream).
+                    if forward_thinking:
+                        # Merge reasoning_content into content as
+                        # <think>…</think> and keep the raw field.
+                        obj = _inject_reasoning_into_chunk(obj)
+                    else:
+                        # Strip the non-standard field so browser clients
+                        # don't see unexpected keys.
+                        try:
+                            delta = obj["choices"][0]["delta"]
+                            delta.pop("reasoning_content", None)
+                            delta.pop("reasoning", None)
+                            obj["choices"][0]["delta"] = delta
+                        except (KeyError, IndexError, TypeError):
+                            pass
+
+                    yield f"data: {json.dumps(obj)}\n\n"
+                    continue
 
+                # ── any other line (comments, keep-alives, …) ─────────────
                 yield chunk
 
     return StreamingResponse(
         event_generator(),
         media_type="text/event-stream",
-        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",
+        },
     )
 
     # ── non-streaming ─────────────────────────
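
Because API-key callers now receive reasoning merged into content as <think>…</think>, a client that wants to show the answer without the reasoning has to strip the tags after accumulating the stream (a single SSE chunk may contain only part of a tag). A rough client-side sketch; the helper name split_think is hypothetical:

import re

def split_think(text: str) -> tuple[str, str]:
    # Returns (reasoning, visible_answer) from the accumulated streamed text.
    reasoning = "".join(re.findall(r"<think>(.*?)</think>", text, flags=re.DOTALL))
    visible = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    return reasoning, visible

# split_think("<think>check the units first</think>42 km") -> ("check the units first", "42 km")
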
@@ -789,7 +1045,11 @@ async def generate_text(
             fb_url, fb_key = _get_provider_url_and_key(FALLBACK_PROVIDER)
             fallback_body = dict(body)
             fallback_body["model"] = FALLBACK_MODEL
-            r = await client.post(fb_url, json=fallback_body, headers={"Authorization": f"Bearer {fb_key}"})
+            r = await client.post(
+                fb_url,
+                json=fallback_body,
+                headers={"Authorization": f"Bearer {fb_key}"},
+            )
 
         content_type = (r.headers.get("content-type") or "").lower()
         if "application/json" in content_type:
@@ -798,22 +1058,35 @@ async def generate_text(
             except Exception:
                 payload = {"error": "Upstream returned invalid JSON"}
             else:
-                # Normalize usage: upstream may use prompt_tokens/completion_tokens
-                # (OpenAI/Groq style) — rewrite to a consistent shape and add
-                # router_metadata so callers always see the same fields.
-                if "usage" in payload and isinstance(payload["usage"], dict):
-                    u = payload["usage"]
-                    input_tok = u.get("prompt_tokens") or u.get("input_tokens", 0)
-                    output_tok = u.get("completion_tokens") or u.get("output_tokens", 0)
-                    payload["usage"] = {
-                        "prompt_tokens": input_tok,
-                        "completion_tokens": output_tok,
-                        "total_tokens": input_tok + output_tok,
-                        # also include the OpenAI Responses-API names for clients that expect them
-                        "input_tokens": input_tok,
-                        "output_tokens": output_tok,
-                    }
-                payload.setdefault("router_metadata", {})["model_name"] = MODEL_MAP.get(chosen_model, chosen_model)
+                # Normalize usage fields.
+                _normalize_usage_block(payload)
+
+                # ── thinking tokens in non-streaming responses ────────────────────
+                # Some navy models put thinking content in
+                # message.reasoning_content. For API-key callers we prepend it to
+                # message.content wrapped in <think>…</think>; for others we drop
+                # the non-standard field.
+                try:
+                    message = payload["choices"][0]["message"]
+                    reasoning = (
+                        message.pop("reasoning_content", None)
+                        or message.pop("reasoning", None)
+                        or ""
+                    )
+                    if reasoning and isinstance(reasoning, str):
+                        if forward_thinking:
+                            existing = message.get("content") or ""
+                            message["content"] = f"<think>{reasoning}</think>{existing}"
+                            # Restore the raw field for clients that want it.
+                            message["reasoning_content"] = reasoning
+                        # else: already popped — nothing to do.
+                    payload["choices"][0]["message"] = message
+                except (KeyError, IndexError, TypeError):
+                    pass
+
+                payload.setdefault("router_metadata", {})["model_name"] = MODEL_MAP.get(
+                    chosen_model, chosen_model
+                )
         else:
             payload = {
                 "error": "Upstream returned non-JSON response",
@@ -1063,8 +1336,24 @@ async def create_responses(
             },
         })
 
+        # ── Run _generate() in the background, pinging every 15 s ──────────────
+        # Without keepalive bytes, Cloudflare (524) and Codex both drop the
+        # connection while the model is thinking or accumulating tool arguments.
+        # SSE comment lines (": ping") are invisible to application code but
+        # reset every proxy's idle-timeout counter.
+        PING_INTERVAL = 15  # seconds
+        gen_task: asyncio.Task = asyncio.ensure_future(_generate())
+
+        while not gen_task.done():
+            try:
+                await asyncio.wait_for(asyncio.shield(gen_task), timeout=PING_INTERVAL)
+            except asyncio.TimeoutError:
+                yield ": ping\n\n"
+            except Exception:
+                break  # real error — handled below
+
         try:
-            text, tool_calls, input_tokens, output_tokens = await _generate()
+            text, tool_calls, input_tokens, output_tokens = gen_task.result()
         except HTTPException as exc:
             yield sse("response.failed", {
                 "type": "response.failed",
@@ -1076,6 +1365,17 @@ async def create_responses(
             })
             yield "data: [DONE]\n\n"
             return
+        except Exception as exc:
+            yield sse("response.failed", {
+                "type": "response.failed",
+                "response": {
+                    "id": response_id, "object": "response",
+                    "created_at": ts, "status": "failed", "model": chosen_model,
+                    "error": {"code": "upstream_error", "message": str(exc)},
+                },
+            })
+            yield "data: [DONE]\n\n"
+            return
 
         output_index = 0
 
 