Spaces:
Running
Running
Update gen.py
Browse files
gen.py
CHANGED
|
@@ -789,6 +789,10 @@ async def generate_text(
|
|
| 789 |
|
| 790 |
yield chunk
|
| 791 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
return StreamingResponse(
|
| 793 |
event_generator(),
|
| 794 |
media_type="text/event-stream",
|
|
@@ -945,15 +949,7 @@ def return_models_openai():
|
|
| 945 |
|
| 946 |
import time
|
| 947 |
import uuid
|
| 948 |
-
|
| 949 |
def _normalize_responses_input(input_field) -> list[dict]:
|
| 950 |
-
"""
|
| 951 |
-
Coerce the Responses API `input` field into a standard messages[] list.
|
| 952 |
-
|
| 953 |
-
Accepted shapes:
|
| 954 |
-
• str → [{"role":"user","content":"..."}]
|
| 955 |
-
• list of message-like dicts → pass through, normalising content parts
|
| 956 |
-
"""
|
| 957 |
if isinstance(input_field, str):
|
| 958 |
return [{"role": "user", "content": input_field}]
|
| 959 |
|
|
@@ -982,9 +978,6 @@ def _normalize_responses_input(input_field) -> list[dict]:
|
|
| 982 |
|
| 983 |
|
| 984 |
def _wrap_responses_output(chat_payload: dict, model_name: str) -> dict:
|
| 985 |
-
"""
|
| 986 |
-
Wrap a standard chat-completions JSON response into the Responses API shape.
|
| 987 |
-
"""
|
| 988 |
choices = chat_payload.get("choices", [])
|
| 989 |
output = []
|
| 990 |
|
|
@@ -1033,10 +1026,6 @@ def _wrap_responses_output(chat_payload: dict, model_name: str) -> dict:
|
|
| 1033 |
def _wrap_responses_stream_chunk(
|
| 1034 |
line: str, response_id: str, model_name: str, sent_created: bool
|
| 1035 |
) -> tuple[str, bool, bool]:
|
| 1036 |
-
"""
|
| 1037 |
-
Returns (translated_sse_string, sent_created, is_done).
|
| 1038 |
-
is_done=True means the upstream [DONE] was processed and response.completed was emitted.
|
| 1039 |
-
"""
|
| 1040 |
if not line.startswith("data:"):
|
| 1041 |
return line + "\n", sent_created, False
|
| 1042 |
|
|
@@ -1141,22 +1130,6 @@ async def create_response(
|
|
| 1141 |
authorization: Optional[str] = Header(None),
|
| 1142 |
x_client_id: Optional[str] = Header(None),
|
| 1143 |
):
|
| 1144 |
-
"""
|
| 1145 |
-
OpenAI Responses API-compatible endpoint.
|
| 1146 |
-
|
| 1147 |
-
Accepts the Responses API request shape, normalises it into the chat
|
| 1148 |
-
completions format, routes it through the existing generate_text logic,
|
| 1149 |
-
and wraps the result back into the Responses API shape.
|
| 1150 |
-
|
| 1151 |
-
Supported fields:
|
| 1152 |
-
• input (str | list) — required
|
| 1153 |
-
• model (str) — accepted but ignored (router decides)
|
| 1154 |
-
• stream (bool) — optional, default False
|
| 1155 |
-
• tools (list) — optional, forwarded as-is
|
| 1156 |
-
• tool_choice (str|dict) — optional, forwarded as-is
|
| 1157 |
-
• temperature (float) — optional, forwarded
|
| 1158 |
-
• max_output_tokens (int) — mapped to max_tokens
|
| 1159 |
-
"""
|
| 1160 |
body = await request.json()
|
| 1161 |
|
| 1162 |
input_field = body.get("input")
|
|
@@ -1178,8 +1151,13 @@ async def create_response(
|
|
| 1178 |
|
| 1179 |
raw_body = json.dumps(chat_body).encode()
|
| 1180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1181 |
scope = dict(request.scope)
|
| 1182 |
scope["path"] = "/gen/chat/completions"
|
|
|
|
| 1183 |
scope["headers"] = [
|
| 1184 |
(k, v) for k, v in request.scope["headers"]
|
| 1185 |
if k.lower() not in (b"content-length",)
|
|
@@ -1192,37 +1170,40 @@ async def create_response(
|
|
| 1192 |
|
| 1193 |
response = await generate_text(sub_request, authorization, x_client_id)
|
| 1194 |
|
|
|
|
| 1195 |
if chat_body.get("stream"):
|
| 1196 |
response_id = f"resp_{uuid.uuid4().hex[:24]}"
|
| 1197 |
model_label = MODEL_MAP.get("lightning", "lightning")
|
| 1198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1199 |
async def responses_stream():
|
| 1200 |
sent_created = False
|
| 1201 |
completed = False
|
| 1202 |
buffer = ""
|
| 1203 |
-
|
| 1204 |
-
async for chunk in
|
| 1205 |
if isinstance(chunk, bytes):
|
| 1206 |
chunk = chunk.decode("utf-8", errors="replace")
|
| 1207 |
buffer += chunk
|
| 1208 |
-
|
| 1209 |
while "\n" in buffer:
|
| 1210 |
line, buffer = buffer.split("\n", 1)
|
| 1211 |
line = line.rstrip("\r")
|
| 1212 |
if not line:
|
| 1213 |
continue
|
| 1214 |
-
|
| 1215 |
translated, new_sent, is_done = _wrap_responses_stream_chunk(
|
| 1216 |
line, response_id, model_label, sent_created
|
| 1217 |
)
|
| 1218 |
sent_created = new_sent
|
| 1219 |
-
|
| 1220 |
if is_done:
|
| 1221 |
completed = True
|
| 1222 |
-
|
| 1223 |
if translated:
|
| 1224 |
yield translated
|
| 1225 |
-
|
| 1226 |
if buffer.strip():
|
| 1227 |
translated, new_sent, is_done = _wrap_responses_stream_chunk(
|
| 1228 |
buffer.strip(), response_id, model_label, sent_created
|
|
@@ -1232,11 +1213,8 @@ async def create_response(
|
|
| 1232 |
completed = True
|
| 1233 |
if translated:
|
| 1234 |
yield translated
|
| 1235 |
-
|
| 1236 |
-
print(f"[RESPONSES STREAM] generator exhausted. sent_created={sent_created} completed={completed}")
|
| 1237 |
-
|
| 1238 |
if sent_created and not completed:
|
| 1239 |
-
print("[RESPONSES STREAM] guard firing — upstream closed without [DONE]")
|
| 1240 |
done_event = json.dumps({
|
| 1241 |
"type": "response.completed",
|
| 1242 |
"response": {
|
|
@@ -1260,6 +1238,7 @@ async def create_response(
|
|
| 1260 |
},
|
| 1261 |
)
|
| 1262 |
|
|
|
|
| 1263 |
if hasattr(response, "body"):
|
| 1264 |
raw = response.body
|
| 1265 |
else:
|
|
|
|
| 789 |
|
| 790 |
yield chunk
|
| 791 |
|
| 792 |
+
holder = request.scope.get("_stream_holder")
|
| 793 |
+
if holder is not None:
|
| 794 |
+
holder["generator"] = event_generator
|
| 795 |
+
|
| 796 |
return StreamingResponse(
|
| 797 |
event_generator(),
|
| 798 |
media_type="text/event-stream",
|
|
|
|
| 949 |
|
| 950 |
import time
|
| 951 |
import uuid
|
|
|
|
| 952 |
def _normalize_responses_input(input_field) -> list[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
if isinstance(input_field, str):
|
| 954 |
return [{"role": "user", "content": input_field}]
|
| 955 |
|
|
|
|
| 978 |
|
| 979 |
|
| 980 |
def _wrap_responses_output(chat_payload: dict, model_name: str) -> dict:
|
|
|
|
|
|
|
|
|
|
| 981 |
choices = chat_payload.get("choices", [])
|
| 982 |
output = []
|
| 983 |
|
|
|
|
| 1026 |
def _wrap_responses_stream_chunk(
|
| 1027 |
line: str, response_id: str, model_name: str, sent_created: bool
|
| 1028 |
) -> tuple[str, bool, bool]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1029 |
if not line.startswith("data:"):
|
| 1030 |
return line + "\n", sent_created, False
|
| 1031 |
|
|
|
|
| 1130 |
authorization: Optional[str] = Header(None),
|
| 1131 |
x_client_id: Optional[str] = Header(None),
|
| 1132 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1133 |
body = await request.json()
|
| 1134 |
|
| 1135 |
input_field = body.get("input")
|
|
|
|
| 1151 |
|
| 1152 |
raw_body = json.dumps(chat_body).encode()
|
| 1153 |
|
| 1154 |
+
# shared dict passed by reference through scope — generate_text writes
|
| 1155 |
+
# the generator factory here, create_response reads it back after the await
|
| 1156 |
+
stream_holder = {}
|
| 1157 |
+
|
| 1158 |
scope = dict(request.scope)
|
| 1159 |
scope["path"] = "/gen/chat/completions"
|
| 1160 |
+
scope["_stream_holder"] = stream_holder
|
| 1161 |
scope["headers"] = [
|
| 1162 |
(k, v) for k, v in request.scope["headers"]
|
| 1163 |
if k.lower() not in (b"content-length",)
|
|
|
|
| 1170 |
|
| 1171 |
response = await generate_text(sub_request, authorization, x_client_id)
|
| 1172 |
|
| 1173 |
+
# --- streaming path ---
|
| 1174 |
if chat_body.get("stream"):
|
| 1175 |
response_id = f"resp_{uuid.uuid4().hex[:24]}"
|
| 1176 |
model_label = MODEL_MAP.get("lightning", "lightning")
|
| 1177 |
|
| 1178 |
+
raw_generator = stream_holder.get("generator")
|
| 1179 |
+
if raw_generator is None:
|
| 1180 |
+
raise HTTPException(500, "Stream generator not captured")
|
| 1181 |
+
|
| 1182 |
async def responses_stream():
|
| 1183 |
sent_created = False
|
| 1184 |
completed = False
|
| 1185 |
buffer = ""
|
| 1186 |
+
|
| 1187 |
+
async for chunk in raw_generator():
|
| 1188 |
if isinstance(chunk, bytes):
|
| 1189 |
chunk = chunk.decode("utf-8", errors="replace")
|
| 1190 |
buffer += chunk
|
| 1191 |
+
|
| 1192 |
while "\n" in buffer:
|
| 1193 |
line, buffer = buffer.split("\n", 1)
|
| 1194 |
line = line.rstrip("\r")
|
| 1195 |
if not line:
|
| 1196 |
continue
|
| 1197 |
+
|
| 1198 |
translated, new_sent, is_done = _wrap_responses_stream_chunk(
|
| 1199 |
line, response_id, model_label, sent_created
|
| 1200 |
)
|
| 1201 |
sent_created = new_sent
|
|
|
|
| 1202 |
if is_done:
|
| 1203 |
completed = True
|
|
|
|
| 1204 |
if translated:
|
| 1205 |
yield translated
|
| 1206 |
+
|
| 1207 |
if buffer.strip():
|
| 1208 |
translated, new_sent, is_done = _wrap_responses_stream_chunk(
|
| 1209 |
buffer.strip(), response_id, model_label, sent_created
|
|
|
|
| 1213 |
completed = True
|
| 1214 |
if translated:
|
| 1215 |
yield translated
|
| 1216 |
+
|
|
|
|
|
|
|
| 1217 |
if sent_created and not completed:
|
|
|
|
| 1218 |
done_event = json.dumps({
|
| 1219 |
"type": "response.completed",
|
| 1220 |
"response": {
|
|
|
|
| 1238 |
},
|
| 1239 |
)
|
| 1240 |
|
| 1241 |
+
# --- non-streaming path ---
|
| 1242 |
if hasattr(response, "body"):
|
| 1243 |
raw = response.body
|
| 1244 |
else:
|