sharktide committed on
Commit
d45cb85
·
verified ·
1 Parent(s): 7e9f3f5

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +21 -42
gen.py CHANGED
@@ -789,6 +789,10 @@ async def generate_text(
789
 
790
  yield chunk
791
 
 
 
 
 
792
  return StreamingResponse(
793
  event_generator(),
794
  media_type="text/event-stream",
@@ -945,15 +949,7 @@ def return_models_openai():
945
 
946
  import time
947
  import uuid
948
-
949
  def _normalize_responses_input(input_field) -> list[dict]:
950
- """
951
- Coerce the Responses API `input` field into a standard messages[] list.
952
-
953
- Accepted shapes:
954
- • str → [{"role":"user","content":"..."}]
955
- • list of message-like dicts → pass through, normalising content parts
956
- """
957
  if isinstance(input_field, str):
958
  return [{"role": "user", "content": input_field}]
959
 
@@ -982,9 +978,6 @@ def _normalize_responses_input(input_field) -> list[dict]:
982
 
983
 
984
  def _wrap_responses_output(chat_payload: dict, model_name: str) -> dict:
985
- """
986
- Wrap a standard chat-completions JSON response into the Responses API shape.
987
- """
988
  choices = chat_payload.get("choices", [])
989
  output = []
990
 
@@ -1033,10 +1026,6 @@ def _wrap_responses_output(chat_payload: dict, model_name: str) -> dict:
1033
  def _wrap_responses_stream_chunk(
1034
  line: str, response_id: str, model_name: str, sent_created: bool
1035
  ) -> tuple[str, bool, bool]:
1036
- """
1037
- Returns (translated_sse_string, sent_created, is_done).
1038
- is_done=True means the upstream [DONE] was processed and response.completed was emitted.
1039
- """
1040
  if not line.startswith("data:"):
1041
  return line + "\n", sent_created, False
1042
 
@@ -1141,22 +1130,6 @@ async def create_response(
1141
  authorization: Optional[str] = Header(None),
1142
  x_client_id: Optional[str] = Header(None),
1143
  ):
1144
- """
1145
- OpenAI Responses API-compatible endpoint.
1146
-
1147
- Accepts the Responses API request shape, normalises it into the chat
1148
- completions format, routes it through the existing generate_text logic,
1149
- and wraps the result back into the Responses API shape.
1150
-
1151
- Supported fields:
1152
- • input (str | list) — required
1153
- • model (str) — accepted but ignored (router decides)
1154
- • stream (bool) — optional, default False
1155
- • tools (list) — optional, forwarded as-is
1156
- • tool_choice (str|dict) — optional, forwarded as-is
1157
- • temperature (float) — optional, forwarded
1158
- • max_output_tokens (int) — mapped to max_tokens
1159
- """
1160
  body = await request.json()
1161
 
1162
  input_field = body.get("input")
@@ -1178,8 +1151,13 @@ async def create_response(
1178
 
1179
  raw_body = json.dumps(chat_body).encode()
1180
 
 
 
 
 
1181
  scope = dict(request.scope)
1182
  scope["path"] = "/gen/chat/completions"
 
1183
  scope["headers"] = [
1184
  (k, v) for k, v in request.scope["headers"]
1185
  if k.lower() not in (b"content-length",)
@@ -1192,37 +1170,40 @@ async def create_response(
1192
 
1193
  response = await generate_text(sub_request, authorization, x_client_id)
1194
 
 
1195
  if chat_body.get("stream"):
1196
  response_id = f"resp_{uuid.uuid4().hex[:24]}"
1197
  model_label = MODEL_MAP.get("lightning", "lightning")
1198
 
 
 
 
 
1199
  async def responses_stream():
1200
  sent_created = False
1201
  completed = False
1202
  buffer = ""
1203
-
1204
- async for chunk in response.body_iterator:
1205
  if isinstance(chunk, bytes):
1206
  chunk = chunk.decode("utf-8", errors="replace")
1207
  buffer += chunk
1208
-
1209
  while "\n" in buffer:
1210
  line, buffer = buffer.split("\n", 1)
1211
  line = line.rstrip("\r")
1212
  if not line:
1213
  continue
1214
-
1215
  translated, new_sent, is_done = _wrap_responses_stream_chunk(
1216
  line, response_id, model_label, sent_created
1217
  )
1218
  sent_created = new_sent
1219
-
1220
  if is_done:
1221
  completed = True
1222
-
1223
  if translated:
1224
  yield translated
1225
-
1226
  if buffer.strip():
1227
  translated, new_sent, is_done = _wrap_responses_stream_chunk(
1228
  buffer.strip(), response_id, model_label, sent_created
@@ -1232,11 +1213,8 @@ async def create_response(
1232
  completed = True
1233
  if translated:
1234
  yield translated
1235
-
1236
- print(f"[RESPONSES STREAM] generator exhausted. sent_created={sent_created} completed={completed}")
1237
-
1238
  if sent_created and not completed:
1239
- print("[RESPONSES STREAM] guard firing — upstream closed without [DONE]")
1240
  done_event = json.dumps({
1241
  "type": "response.completed",
1242
  "response": {
@@ -1260,6 +1238,7 @@ async def create_response(
1260
  },
1261
  )
1262
 
 
1263
  if hasattr(response, "body"):
1264
  raw = response.body
1265
  else:
 
789
 
790
  yield chunk
791
 
792
+ holder = request.scope.get("_stream_holder")
793
+ if holder is not None:
794
+ holder["generator"] = event_generator
795
+
796
  return StreamingResponse(
797
  event_generator(),
798
  media_type="text/event-stream",
 
949
 
950
  import time
951
  import uuid
 
952
  def _normalize_responses_input(input_field) -> list[dict]:
 
 
 
 
 
 
 
953
  if isinstance(input_field, str):
954
  return [{"role": "user", "content": input_field}]
955
 
 
978
 
979
 
980
  def _wrap_responses_output(chat_payload: dict, model_name: str) -> dict:
 
 
 
981
  choices = chat_payload.get("choices", [])
982
  output = []
983
 
 
1026
  def _wrap_responses_stream_chunk(
1027
  line: str, response_id: str, model_name: str, sent_created: bool
1028
  ) -> tuple[str, bool, bool]:
 
 
 
 
1029
  if not line.startswith("data:"):
1030
  return line + "\n", sent_created, False
1031
 
 
1130
  authorization: Optional[str] = Header(None),
1131
  x_client_id: Optional[str] = Header(None),
1132
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1133
  body = await request.json()
1134
 
1135
  input_field = body.get("input")
 
1151
 
1152
  raw_body = json.dumps(chat_body).encode()
1153
 
1154
+ # shared dict passed by reference through scope — generate_text writes
1155
+ # the generator factory here, create_response reads it back after the await
1156
+ stream_holder = {}
1157
+
1158
  scope = dict(request.scope)
1159
  scope["path"] = "/gen/chat/completions"
1160
+ scope["_stream_holder"] = stream_holder
1161
  scope["headers"] = [
1162
  (k, v) for k, v in request.scope["headers"]
1163
  if k.lower() not in (b"content-length",)
 
1170
 
1171
  response = await generate_text(sub_request, authorization, x_client_id)
1172
 
1173
+ # --- streaming path ---
1174
  if chat_body.get("stream"):
1175
  response_id = f"resp_{uuid.uuid4().hex[:24]}"
1176
  model_label = MODEL_MAP.get("lightning", "lightning")
1177
 
1178
+ raw_generator = stream_holder.get("generator")
1179
+ if raw_generator is None:
1180
+ raise HTTPException(500, "Stream generator not captured")
1181
+
1182
  async def responses_stream():
1183
  sent_created = False
1184
  completed = False
1185
  buffer = ""
1186
+
1187
+ async for chunk in raw_generator():
1188
  if isinstance(chunk, bytes):
1189
  chunk = chunk.decode("utf-8", errors="replace")
1190
  buffer += chunk
1191
+
1192
  while "\n" in buffer:
1193
  line, buffer = buffer.split("\n", 1)
1194
  line = line.rstrip("\r")
1195
  if not line:
1196
  continue
1197
+
1198
  translated, new_sent, is_done = _wrap_responses_stream_chunk(
1199
  line, response_id, model_label, sent_created
1200
  )
1201
  sent_created = new_sent
 
1202
  if is_done:
1203
  completed = True
 
1204
  if translated:
1205
  yield translated
1206
+
1207
  if buffer.strip():
1208
  translated, new_sent, is_done = _wrap_responses_stream_chunk(
1209
  buffer.strip(), response_id, model_label, sent_created
 
1213
  completed = True
1214
  if translated:
1215
  yield translated
1216
+
 
 
1217
  if sent_created and not completed:
 
1218
  done_event = json.dumps({
1219
  "type": "response.completed",
1220
  "response": {
 
1238
  },
1239
  )
1240
 
1241
+ # --- non-streaming path ---
1242
  if hasattr(response, "body"):
1243
  raw = response.body
1244
  else: